From 4db619457f56231058001f17ba2616b9e5c65571 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 10:31:59 +0200 Subject: [PATCH 01/35] setup --- setup.py | 51 ++----------------- src/transformers/dependency_versions_table.py | 7 --- 2 files changed, 5 insertions(+), 53 deletions(-) diff --git a/setup.py b/setup.py index 9f3bb1750597..f0f1598f4971 100644 --- a/setup.py +++ b/setup.py @@ -122,9 +122,6 @@ "jaxlib>=0.4.1,<=0.4.13", "jinja2>=3.1.0", "kenlm", - # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. - "keras>2.9,<2.16", - "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. "kernels>=0.6.1,<=0.9", "librosa", "natten>=0.14.6,<0.15.0", @@ -177,12 +174,6 @@ "sudachipy>=0.6.6", "sudachidict_core>=20220729", "tensorboard", - # TensorFlow pin. When changing this value, update examples/tensorflow/_tests_requirements.txt accordingly - "tensorflow-cpu>2.9,<2.16", - "tensorflow>2.9,<2.16", - "tensorflow-text<2.16", - "tensorflow-probability<0.24", - "tf2onnx", "timeout-decorator", "tiktoken", "timm<=1.0.19,!=1.0.18", @@ -273,32 +264,19 @@ def run(self): extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp") extras["sklearn"] = deps_list("scikit-learn") -extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["tf-cpu"] = deps_list( - "keras", - "tensorflow-cpu", - "onnxconverter-common", - "tf2onnx", - "tensorflow-text", - "keras-nlp", - "tensorflow-probability", -) - extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") extras["hf_xet"] = deps_list("hf_xet") if os.name == "nt": # windows extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows - extras["flax"] = [] # jax is not supported on windows else: extras["retrieval"] = deps_list("faiss-cpu", "datasets") - extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax", "scipy") extras["tokenizers"] = deps_list("tokenizers") extras["ftfy"] = deps_list("ftfy") extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools") -extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"] +extras["onnx"] = deps_list("onnxconverter-common") + extras["onnxruntime"] extras["modelcreation"] = deps_list("cookiecutter") extras["sagemaker"] = deps_list("sagemaker") @@ -320,8 +298,6 @@ def run(self): # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead extras["speech"] = deps_list("torchaudio") + extras["audio"] extras["torch-speech"] = deps_list("torchaudio") + extras["audio"] -extras["tf-speech"] = extras["audio"] -extras["flax-speech"] = extras["audio"] extras["vision"] = deps_list("Pillow") extras["timm"] = deps_list("timm") extras["torch-vision"] = deps_list("torchvision") + extras["vision"] @@ -372,9 +348,7 @@ def run(self): extras["quality"] = deps_list("datasets", "ruff", "GitPython", "urllib3", "libcst", "rich", "pandas") extras["all"] = ( - extras["tf"] - + extras["torch"] - + extras["flax"] + extras["torch"] + extras["sentencepiece"] + extras["tokenizers"] + extras["torch-speech"] @@ -409,18 +383,7 @@ def run(self): + extras["onnxruntime"] + extras["num2words"] ) -extras["dev-tensorflow"] = ( - extras["testing"] - + extras["tf"] - + extras["sentencepiece"] - + extras["tokenizers"] - + extras["vision"] - + extras["quality"] - + extras["sklearn"] - + extras["modelcreation"] - + extras["onnx"] - + extras["tf-speech"] -) + extras["dev"] = ( extras["all"] + extras["testing"] + extras["quality"] + extras["ja"] + extras["sklearn"] + extras["modelcreation"] ) @@ -464,10 +427,10 @@ def run(self): version="4.57.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", author_email="transformers@huggingface.co", - description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", + description="Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", - keywords="NLP vision speech deep learning transformer pytorch tensorflow jax BERT GPT-2 Wav2Vec2 ViT", + keywords="machine-learning nlp python pytorch transformer llm vlm deep-learning inference training model-hub pretrained-models llama gemma qwen", license="Apache 2.0 License", url="https://github.com/huggingface/transformers", package_dir={"": "src"}, @@ -503,14 +466,10 @@ def run(self): ) extras["tests_torch"] = deps_list() -extras["tests_tf"] = deps_list() -extras["tests_flax"] = deps_list() extras["tests_hub"] = deps_list() extras["tests_pipelines_torch"] = deps_list() -extras["tests_pipelines_tf"] = deps_list() extras["tests_onnx"] = deps_list() extras["tests_examples_torch"] = deps_list() -extras["tests_examples_tf"] = deps_list() extras["tests_custom_tokenizers"] = deps_list() extras["tests_exotic_models"] = deps_list() extras["consistency"] = deps_list() diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index ab6e747d14db..b9b70d256e8b 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -31,8 +31,6 @@ "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jinja2": "jinja2>=3.1.0", "kenlm": "kenlm", - "keras": "keras>2.9,<2.16", - "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", "kernels": "kernels>=0.6.1,<=0.9", "librosa": "librosa", "natten": "natten>=0.14.6,<0.15.0", @@ -82,11 +80,6 @@ "sudachipy": "sudachipy>=0.6.6", "sudachidict_core": "sudachidict_core>=20220729", "tensorboard": "tensorboard", - "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16", - "tensorflow": "tensorflow>2.9,<2.16", - "tensorflow-text": "tensorflow-text<2.16", - "tensorflow-probability": "tensorflow-probability<0.24", - "tf2onnx": "tf2onnx", "timeout-decorator": "timeout-decorator", "tiktoken": "tiktoken", "timm": "timm<=1.0.19,!=1.0.18", From 2b4bf6cc3e255ec0377297d7bb1ea446b366e41e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 10:55:37 +0200 Subject: [PATCH 02/35] start the purge --- docs/source/ar/tflite.md | 40 - docs/source/en/tflite.md | 66 - docs/source/hi/tflite.md | 55 - docs/source/ja/perf_train_tpu_tf.md | 168 - docs/source/ja/tf_xla.md | 179 - docs/source/ja/tflite.md | 58 - docs/source/ko/tflite.md | 62 - docs/source/zh/tf_xla.md | 179 - docs/source/zh/tflite.md | 54 - src/transformers/activations_tf.py | 147 - ...nvert_tf_hub_seq_to_seq_bert_to_pytorch.py | 86 - .../generation/flax_logits_process.py | 544 --- src/transformers/generation/flax_utils.py | 1032 ----- .../generation/tf_logits_process.py | 600 --- src/transformers/generation/tf_utils.py | 3132 --------------- src/transformers/modeling_flax_outputs.py | 700 ---- .../modeling_flax_pytorch_utils.py | 491 --- src/transformers/modeling_flax_utils.py | 1274 ------ src/transformers/modeling_tf_outputs.py | 990 ----- src/transformers/modeling_tf_pytorch_utils.py | 676 ---- src/transformers/modeling_tf_utils.py | 3529 ----------------- ...lbert_original_tf_checkpoint_to_pytorch.py | 62 - .../models/albert/modeling_flax_albert.py | 1132 ------ .../models/albert/modeling_tf_albert.py | 1572 -------- .../models/align/convert_align_tf_to_hf.py | 389 -- .../models/auto/modeling_flax_auto.py | 413 -- .../models/auto/modeling_tf_auto.py | 776 ---- .../models/bart/modeling_flax_bart.py | 2006 ---------- .../models/bart/modeling_tf_bart.py | 1713 -------- .../models/beit/modeling_flax_beit.py | 956 ----- ...bert_original_tf2_checkpoint_to_pytorch.py | 246 -- ..._bert_original_tf_checkpoint_to_pytorch.py | 62 - ..._bert_pytorch_checkpoint_to_original_tf.py | 112 - ...ping_original_tf2_checkpoint_to_pytorch.py | 188 - .../models/bert/modeling_flax_bert.py | 1727 -------- .../models/bert/modeling_tf_bert.py | 2125 ---------- .../models/bert/tokenization_bert_tf.py | 259 -- ...gbird_original_tf_checkpoint_to_pytorch.py | 69 - .../models/big_bird/modeling_flax_big_bird.py | 2648 ------------- .../convert_bigbird_pegasus_tf_to_pytorch.py | 169 - .../blenderbot/modeling_flax_blenderbot.py | 1508 ------- .../blenderbot/modeling_tf_blenderbot.py | 1557 -------- .../modeling_flax_blenderbot_small.py | 1528 ------- .../modeling_tf_blenderbot_small.py | 1527 ------- .../models/blip/modeling_tf_blip.py | 1709 -------- .../models/blip/modeling_tf_blip_text.py | 1122 ------ .../models/bloom/modeling_flax_bloom.py | 737 ---- ..._byt5_original_tf_checkpoint_to_pytorch.py | 59 - .../models/camembert/modeling_tf_camembert.py | 1800 --------- ...anine_original_tf_checkpoint_to_pytorch.py | 65 - .../models/clip/modeling_flax_clip.py | 1306 ------ .../models/clip/modeling_tf_clip.py | 1460 ------- ...ginal_tf1_checkpoint_to_pytorch_and_tf2.py | 57 - .../models/convbert/modeling_tf_convbert.py | 1474 ------- .../models/convnext/modeling_tf_convnext.py | 667 ---- .../convnextv2/modeling_tf_convnextv2.py | 681 ---- .../models/ctrl/modeling_tf_ctrl.py | 920 ----- .../models/cvt/modeling_tf_cvt.py | 1095 ----- .../data2vec/modeling_tf_data2vec_vision.py | 1723 -------- .../models/deberta/modeling_tf_deberta.py | 1652 -------- .../deberta_v2/modeling_tf_deberta_v2.py | 1879 --------- .../models/deit/modeling_tf_deit.py | 1232 ------ .../modeling_tf_efficientformer.py | 1198 ------ ...convert_gptsan_tf_checkpoint_to_pytorch.py | 181 - ...fo_xl_original_tf_checkpoint_to_pytorch.py | 121 - .../transfo_xl/modeling_tf_transfo_xl.py | 1128 ------ .../modeling_tf_transfo_xl_utilities.py | 178 - .../models/dinov2/modeling_flax_dinov2.py | 801 ---- .../distilbert/modeling_flax_distilbert.py | 906 ----- .../distilbert/modeling_tf_distilbert.py | 1146 ------ .../models/dpr/modeling_tf_dpr.py | 799 ---- ...ectra_original_tf_checkpoint_to_pytorch.py | 79 - .../models/electra/modeling_flax_electra.py | 1614 -------- .../models/electra/modeling_tf_electra.py | 1775 --------- .../modeling_flax_encoder_decoder.py | 901 ----- .../modeling_tf_encoder_decoder.py | 661 --- .../models/esm/modeling_tf_esm.py | 1574 -------- .../models/flaubert/modeling_tf_flaubert.py | 1343 ------- ...net_original_flax_checkpoint_to_pytorch.py | 156 - ...unnel_original_tf_checkpoint_to_pytorch.py | 64 - .../models/funnel/modeling_tf_funnel.py | 1883 --------- .../models/gemma/modeling_flax_gemma.py | 777 ---- ..._gpt2_original_tf_checkpoint_to_pytorch.py | 68 - .../models/gpt2/modeling_flax_gpt2.py | 782 ---- .../models/gpt2/modeling_tf_gpt2.py | 1238 ------ .../models/gpt2/tokenization_gpt2_tf.py | 119 - .../convert_gpt_neo_mesh_tf_to_pytorch.py | 71 - .../models/gpt_neo/modeling_flax_gpt_neo.py | 687 ---- .../models/gptj/modeling_flax_gptj.py | 721 ---- .../models/gptj/modeling_tf_gptj.py | 1094 ----- .../models/groupvit/modeling_tf_groupvit.py | 2141 ---------- .../models/hubert/modeling_tf_hubert.py | 1671 -------- .../models/idefics/modeling_tf_idefics.py | 1778 --------- .../models/idefics/perceiver_tf.py | 195 - src/transformers/models/idefics/vision_tf.py | 572 --- ...onvert_imagegpt_original_tf2_to_pytorch.py | 71 - .../models/layoutlm/modeling_tf_layoutlm.py | 1691 -------- .../layoutlmv3/modeling_tf_layoutlmv3.py | 1767 --------- .../models/led/modeling_tf_led.py | 2663 ------------- .../models/llama/modeling_flax_llama.py | 747 ---- .../longformer/modeling_tf_longformer.py | 2783 ------------- .../models/longt5/modeling_flax_longt5.py | 2449 ------------ ...xmert_original_tf_checkpoint_to_pytorch.py | 59 - .../models/lxmert/modeling_tf_lxmert.py | 1660 -------- .../models/marian/modeling_flax_marian.py | 1500 ------- .../models/marian/modeling_tf_marian.py | 1558 -------- .../models/mbart/modeling_flax_mbart.py | 1780 --------- .../models/mbart/modeling_tf_mbart.py | 1572 -------- .../models/mistral/modeling_flax_mistral.py | 744 ---- .../models/mistral/modeling_tf_mistral.py | 1016 ----- ...ebert_original_tf_checkpoint_to_pytorch.py | 58 - .../mobilebert/modeling_tf_mobilebert.py | 1979 --------- ...nvert_original_tf_checkpoint_to_pytorch.py | 141 - ...nvert_original_tf_checkpoint_to_pytorch.py | 177 - .../models/mobilevit/modeling_tf_mobilevit.py | 1376 ------- .../models/mpnet/modeling_tf_mpnet.py | 1353 ------- .../models/mt5/modeling_flax_mt5.py | 123 - .../models/mt5/modeling_tf_mt5.py | 98 - ..._myt5_original_tf_checkpoint_to_pytorch.py | 60 - ...penai_original_tf_checkpoint_to_pytorch.py | 74 - .../models/openai/modeling_tf_openai.py | 936 ----- .../models/opt/modeling_flax_opt.py | 802 ---- .../models/opt/modeling_tf_opt.py | 1092 ----- .../convert_owlvit_original_flax_to_hf.py | 406 -- .../pegasus/convert_pegasus_tf_to_pytorch.py | 130 - .../models/pegasus/modeling_flax_pegasus.py | 1532 ------- .../models/pegasus/modeling_tf_pegasus.py | 1573 -------- .../models/rag/modeling_tf_rag.py | 1776 --------- .../models/regnet/modeling_flax_regnet.py | 822 ---- .../models/regnet/modeling_tf_regnet.py | 611 --- ...onvert_rembert_tf_checkpoint_to_pytorch.py | 62 - .../models/rembert/modeling_tf_rembert.py | 1720 -------- .../models/resnet/modeling_flax_resnet.py | 704 ---- .../models/resnet/modeling_tf_resnet.py | 596 --- .../models/roberta/modeling_flax_roberta.py | 1500 ------- .../models/roberta/modeling_tf_roberta.py | 1782 --------- .../modeling_flax_roberta_prelayernorm.py | 1527 ------- .../modeling_tf_roberta_prelayernorm.py | 1807 --------- ...ormer_original_tf_checkpoint_to_pytorch.py | 62 - .../models/roformer/modeling_flax_roformer.py | 1091 ----- .../models/roformer/modeling_tf_roformer.py | 1546 -------- .../models/sam/modeling_tf_sam.py | 1723 -------- .../models/segformer/modeling_tf_segformer.py | 1044 ----- .../modeling_flax_speech_encoder_decoder.py | 930 ----- .../convert_s2t_fairseq_to_tfms.py | 121 - .../modeling_tf_speech_to_text.py | 1600 -------- .../swiftformer/modeling_tf_swiftformer.py | 866 ---- .../models/swin/modeling_tf_swin.py | 1639 -------- ...ers_original_flax_checkpoint_to_pytorch.py | 203 - ...rt_t5_original_tf_checkpoint_to_pytorch.py | 59 - .../models/t5/modeling_flax_t5.py | 1801 --------- src/transformers/models/t5/modeling_tf_t5.py | 1676 -------- ...tapas_original_tf_checkpoint_to_pytorch.py | 137 - .../models/tapas/modeling_tf_tapas.py | 2461 ------------ .../modeling_flax_vision_encoder_decoder.py | 864 ---- .../modeling_tf_vision_encoder_decoder.py | 696 ---- .../modeling_flax_vision_text_dual_encoder.py | 601 --- .../modeling_tf_vision_text_dual_encoder.py | 623 --- .../models/vit/modeling_flax_vit.py | 677 ---- .../models/vit/modeling_tf_vit.py | 906 ----- .../models/vit_mae/modeling_tf_vit_mae.py | 1374 ------- .../vivit/convert_vivit_flax_to_pytorch.py | 231 -- .../models/wav2vec2/modeling_flax_wav2vec2.py | 1423 ------- .../models/wav2vec2/modeling_tf_wav2vec2.py | 1855 --------- .../models/whisper/modeling_flax_whisper.py | 1707 -------- .../models/whisper/modeling_tf_whisper.py | 1754 -------- .../models/xglm/modeling_flax_xglm.py | 803 ---- .../models/xglm/modeling_tf_xglm.py | 1002 ----- .../models/xlm/modeling_tf_xlm.py | 1356 ------- .../xlm_roberta/modeling_flax_xlm_roberta.py | 1511 ------- .../xlm_roberta/modeling_tf_xlm_roberta.py | 1790 --------- ...xlnet_original_tf_checkpoint_to_pytorch.py | 113 - .../models/xlnet/modeling_tf_xlnet.py | 1820 --------- src/transformers/optimization_tf.py | 378 -- src/transformers/tf_utils.py | 294 -- src/transformers/training_args_tf.py | 300 -- src/transformers/utils/dummy_flax_objects.py | 107 - src/transformers/utils/dummy_tf_objects.py | 178 - tests/sagemaker/scripts/tensorflow/run_tf.py | 104 - utils/check_tf_ops.py | 101 - 180 files changed, 171270 deletions(-) delete mode 100644 docs/source/ar/tflite.md delete mode 100644 docs/source/en/tflite.md delete mode 100644 docs/source/hi/tflite.md delete mode 100644 docs/source/ja/perf_train_tpu_tf.md delete mode 100644 docs/source/ja/tf_xla.md delete mode 100644 docs/source/ja/tflite.md delete mode 100644 docs/source/ko/tflite.md delete mode 100644 docs/source/zh/tf_xla.md delete mode 100644 docs/source/zh/tflite.md delete mode 100644 src/transformers/activations_tf.py delete mode 100755 src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py delete mode 100644 src/transformers/generation/flax_logits_process.py delete mode 100644 src/transformers/generation/flax_utils.py delete mode 100644 src/transformers/generation/tf_logits_process.py delete mode 100644 src/transformers/generation/tf_utils.py delete mode 100644 src/transformers/modeling_flax_outputs.py delete mode 100644 src/transformers/modeling_flax_pytorch_utils.py delete mode 100644 src/transformers/modeling_flax_utils.py delete mode 100644 src/transformers/modeling_tf_outputs.py delete mode 100644 src/transformers/modeling_tf_pytorch_utils.py delete mode 100644 src/transformers/modeling_tf_utils.py delete mode 100644 src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/albert/modeling_flax_albert.py delete mode 100644 src/transformers/models/albert/modeling_tf_albert.py delete mode 100644 src/transformers/models/align/convert_align_tf_to_hf.py delete mode 100644 src/transformers/models/auto/modeling_flax_auto.py delete mode 100644 src/transformers/models/auto/modeling_tf_auto.py delete mode 100644 src/transformers/models/bart/modeling_flax_bart.py delete mode 100644 src/transformers/models/bart/modeling_tf_bart.py delete mode 100644 src/transformers/models/beit/modeling_flax_beit.py delete mode 100644 src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py delete mode 100755 src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py delete mode 100644 src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/bert/modeling_flax_bert.py delete mode 100644 src/transformers/models/bert/modeling_tf_bert.py delete mode 100644 src/transformers/models/bert/tokenization_bert_tf.py delete mode 100644 src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/big_bird/modeling_flax_big_bird.py delete mode 100644 src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py delete mode 100644 src/transformers/models/blenderbot/modeling_flax_blenderbot.py delete mode 100644 src/transformers/models/blenderbot/modeling_tf_blenderbot.py delete mode 100644 src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py delete mode 100644 src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py delete mode 100644 src/transformers/models/blip/modeling_tf_blip.py delete mode 100644 src/transformers/models/blip/modeling_tf_blip_text.py delete mode 100644 src/transformers/models/bloom/modeling_flax_bloom.py delete mode 100755 src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/camembert/modeling_tf_camembert.py delete mode 100644 src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/clip/modeling_flax_clip.py delete mode 100644 src/transformers/models/clip/modeling_tf_clip.py delete mode 100644 src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py delete mode 100644 src/transformers/models/convbert/modeling_tf_convbert.py delete mode 100644 src/transformers/models/convnext/modeling_tf_convnext.py delete mode 100644 src/transformers/models/convnextv2/modeling_tf_convnextv2.py delete mode 100644 src/transformers/models/ctrl/modeling_tf_ctrl.py delete mode 100644 src/transformers/models/cvt/modeling_tf_cvt.py delete mode 100644 src/transformers/models/data2vec/modeling_tf_data2vec_vision.py delete mode 100644 src/transformers/models/deberta/modeling_tf_deberta.py delete mode 100644 src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py delete mode 100644 src/transformers/models/deit/modeling_tf_deit.py delete mode 100644 src/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py delete mode 100644 src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py delete mode 100644 src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py delete mode 100644 src/transformers/models/dinov2/modeling_flax_dinov2.py delete mode 100644 src/transformers/models/distilbert/modeling_flax_distilbert.py delete mode 100644 src/transformers/models/distilbert/modeling_tf_distilbert.py delete mode 100644 src/transformers/models/dpr/modeling_tf_dpr.py delete mode 100644 src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/electra/modeling_flax_electra.py delete mode 100644 src/transformers/models/electra/modeling_tf_electra.py delete mode 100644 src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py delete mode 100644 src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py delete mode 100644 src/transformers/models/esm/modeling_tf_esm.py delete mode 100644 src/transformers/models/flaubert/modeling_tf_flaubert.py delete mode 100644 src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py delete mode 100755 src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/funnel/modeling_tf_funnel.py delete mode 100644 src/transformers/models/gemma/modeling_flax_gemma.py delete mode 100755 src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/gpt2/modeling_flax_gpt2.py delete mode 100644 src/transformers/models/gpt2/modeling_tf_gpt2.py delete mode 100644 src/transformers/models/gpt2/tokenization_gpt2_tf.py delete mode 100644 src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py delete mode 100644 src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py delete mode 100644 src/transformers/models/gptj/modeling_flax_gptj.py delete mode 100644 src/transformers/models/gptj/modeling_tf_gptj.py delete mode 100644 src/transformers/models/groupvit/modeling_tf_groupvit.py delete mode 100644 src/transformers/models/hubert/modeling_tf_hubert.py delete mode 100644 src/transformers/models/idefics/modeling_tf_idefics.py delete mode 100644 src/transformers/models/idefics/perceiver_tf.py delete mode 100644 src/transformers/models/idefics/vision_tf.py delete mode 100644 src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py delete mode 100644 src/transformers/models/layoutlm/modeling_tf_layoutlm.py delete mode 100644 src/transformers/models/layoutlmv3/modeling_tf_layoutlmv3.py delete mode 100644 src/transformers/models/led/modeling_tf_led.py delete mode 100644 src/transformers/models/llama/modeling_flax_llama.py delete mode 100644 src/transformers/models/longformer/modeling_tf_longformer.py delete mode 100644 src/transformers/models/longt5/modeling_flax_longt5.py delete mode 100755 src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/lxmert/modeling_tf_lxmert.py delete mode 100644 src/transformers/models/marian/modeling_flax_marian.py delete mode 100644 src/transformers/models/marian/modeling_tf_marian.py delete mode 100644 src/transformers/models/mbart/modeling_flax_mbart.py delete mode 100644 src/transformers/models/mbart/modeling_tf_mbart.py delete mode 100644 src/transformers/models/mistral/modeling_flax_mistral.py delete mode 100644 src/transformers/models/mistral/modeling_tf_mistral.py delete mode 100644 src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/mobilebert/modeling_tf_mobilebert.py delete mode 100644 src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/mobilevit/modeling_tf_mobilevit.py delete mode 100644 src/transformers/models/mpnet/modeling_tf_mpnet.py delete mode 100644 src/transformers/models/mt5/modeling_flax_mt5.py delete mode 100644 src/transformers/models/mt5/modeling_tf_mt5.py delete mode 100644 src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py delete mode 100755 src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/openai/modeling_tf_openai.py delete mode 100644 src/transformers/models/opt/modeling_flax_opt.py delete mode 100644 src/transformers/models/opt/modeling_tf_opt.py delete mode 100644 src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py delete mode 100644 src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py delete mode 100644 src/transformers/models/pegasus/modeling_flax_pegasus.py delete mode 100644 src/transformers/models/pegasus/modeling_tf_pegasus.py delete mode 100644 src/transformers/models/rag/modeling_tf_rag.py delete mode 100644 src/transformers/models/regnet/modeling_flax_regnet.py delete mode 100644 src/transformers/models/regnet/modeling_tf_regnet.py delete mode 100755 src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/rembert/modeling_tf_rembert.py delete mode 100644 src/transformers/models/resnet/modeling_flax_resnet.py delete mode 100644 src/transformers/models/resnet/modeling_tf_resnet.py delete mode 100644 src/transformers/models/roberta/modeling_flax_roberta.py delete mode 100644 src/transformers/models/roberta/modeling_tf_roberta.py delete mode 100644 src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py delete mode 100644 src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py delete mode 100755 src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/roformer/modeling_flax_roformer.py delete mode 100644 src/transformers/models/roformer/modeling_tf_roformer.py delete mode 100644 src/transformers/models/sam/modeling_tf_sam.py delete mode 100644 src/transformers/models/segformer/modeling_tf_segformer.py delete mode 100644 src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py delete mode 100644 src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py delete mode 100755 src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py delete mode 100644 src/transformers/models/swiftformer/modeling_tf_swiftformer.py delete mode 100644 src/transformers/models/swin/modeling_tf_swin.py delete mode 100644 src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py delete mode 100755 src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/t5/modeling_flax_t5.py delete mode 100644 src/transformers/models/t5/modeling_tf_t5.py delete mode 100644 src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/tapas/modeling_tf_tapas.py delete mode 100644 src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py delete mode 100644 src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py delete mode 100644 src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py delete mode 100644 src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py delete mode 100644 src/transformers/models/vit/modeling_flax_vit.py delete mode 100644 src/transformers/models/vit/modeling_tf_vit.py delete mode 100644 src/transformers/models/vit_mae/modeling_tf_vit_mae.py delete mode 100644 src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py delete mode 100644 src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py delete mode 100644 src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py delete mode 100644 src/transformers/models/whisper/modeling_flax_whisper.py delete mode 100644 src/transformers/models/whisper/modeling_tf_whisper.py delete mode 100644 src/transformers/models/xglm/modeling_flax_xglm.py delete mode 100644 src/transformers/models/xglm/modeling_tf_xglm.py delete mode 100644 src/transformers/models/xlm/modeling_tf_xlm.py delete mode 100644 src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py delete mode 100644 src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py delete mode 100755 src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/xlnet/modeling_tf_xlnet.py delete mode 100644 src/transformers/optimization_tf.py delete mode 100644 src/transformers/tf_utils.py delete mode 100644 src/transformers/training_args_tf.py delete mode 100644 src/transformers/utils/dummy_flax_objects.py delete mode 100644 src/transformers/utils/dummy_tf_objects.py delete mode 100644 tests/sagemaker/scripts/tensorflow/run_tf.py delete mode 100644 utils/check_tf_ops.py diff --git a/docs/source/ar/tflite.md b/docs/source/ar/tflite.md deleted file mode 100644 index 5e75c7a10a3c..000000000000 --- a/docs/source/ar/tflite.md +++ /dev/null @@ -1,40 +0,0 @@ -# التصدير إلى TFLite - -[TensorFlow Lite](https://www.tensorflow.org/lite/guide) هو إطار عمل خفيف الوزن لنشر نماذج التعلم الآلي على الأجهزة المحدودة الموارد، مثل الهواتف المحمولة، والأنظمة المدمجة، وأجهزة إنترنت الأشياء (IoT). تم تصميم TFLite لتشغيل النماذج وتحسينها بكفاءة على هذه الأجهزة ذات الطاقة الحاسوبية والذاكرة واستهلاك الطاقة المحدودة. - -يُمثَّل نموذج TensorFlow Lite بتنسيق محمول فعال خاص يُعرَّف بامتداد الملف `.tflite`. - -🤗 Optimum يقدم وظيفة لتصدير نماذج 🤗 Transformers إلى TFLite من خلال الوحدة النمطية `exporters.tflite`. بالنسبة لقائمة هندسات النماذج المدعومة، يرجى الرجوع إلى [وثائق 🤗 Optimum](https://huggingface.co/docs/optimum/exporters/tflite/overview). - -لتصدير نموذج إلى TFLite، قم بتثبيت متطلبات البرنامج المطلوبة: - -```bash -pip install optimum[exporters-tf] -``` - -للاطلاع على جميع المغامﻻت المتاحة، راجع [وثائق 🤗 Optimum](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model)، أو عرض المساعدة في سطر الأوامر: - -```bash -optimum-cli export tflite --help -``` - -لتصدير نسخة النموذج ل 🤗 Hub، على سبيل المثال، `google-bert/bert-base-uncased`، قم بتشغيل الأمر التالي: - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -ستظهر لك السجلات التي تُبيّن التقدم وموقع حفظ ملف `model.tflite` الناتج، كما في المثال التالي: - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite -``` - -يُبيّن المثال أعلاه كيفية تصدير نسخة من النموذج ل 🤗 Hub. عند تصدير نموذج محلي، تأكد أولاً من حفظ ملفات أوزان النموذج المجزء اللغوى في نفس المسار (`local_path`). عند استخدام CLI، قم بتمرير `local_path` إلى معامل `model` بدلاً من اسم النسخة على 🤗 Hub. \ No newline at end of file diff --git a/docs/source/en/tflite.md b/docs/source/en/tflite.md deleted file mode 100644 index 8dfdbeed464d..000000000000 --- a/docs/source/en/tflite.md +++ /dev/null @@ -1,66 +0,0 @@ - - -# LiteRT - -[LiteRT](https://ai.google.dev/edge/litert) (previously known as TensorFlow Lite) is a high-performance runtime designed for on-device machine learning. - -The [Optimum](https://huggingface.co/docs/optimum/index) library exports a model to LiteRT for [many architectures](https://huggingface.co/docs/optimum/exporters/onnx/overview). - -The benefits of exporting to LiteRT include the following. - -- Low-latency, privacy-focused, no internet connectivity required, and reduced model size and power consumption for on-device machine learning. -- Broad platform, model framework, and language support. -- Hardware acceleration for GPUs and Apple Silicon. - -Export a Transformers model to LiteRT with the Optimum CLI. - -Run the command below to install Optimum and the [exporters](https://huggingface.co/docs/optimum/exporters/overview) module for LiteRT. - -```bash -pip install optimum[exporters-tf] -``` - -> [!TIP] -> Refer to the [Export a model to TFLite with optimum.exporters.tflite](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model) guide for all available arguments or with the command below. -> ```bash -> optimum-cli export tflite --help -> ``` - -Set the `--model` argument to export a from the Hub. - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -You should see logs indicating the progress and showing where the resulting `model.tflite` is saved. - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite - ``` - -For local models, make sure the model weights and tokenizer files are saved in the same directory, for example `local_path`. Pass the directory to the `--model` argument and use `--task` to indicate the [task](https://huggingface.co/docs/optimum/exporters/task_manager) a model can perform. If `--task` isn't provided, the model architecture without a task-specific head is used. - -```bash -optimum-cli export tflite --model local_path --task question-answering google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` diff --git a/docs/source/hi/tflite.md b/docs/source/hi/tflite.md deleted file mode 100644 index 5a84bed94266..000000000000 --- a/docs/source/hi/tflite.md +++ /dev/null @@ -1,55 +0,0 @@ - - -# TFLite में निर्यात करें - -[TensorFlow Lite](https://www.tensorflow.org/lite/guide) एक हल्का ढांचा है जो मशीन लर्निंग मॉडल को संसाधन-सीमित उपकरणों, जैसे मोबाइल फोन, एम्बेडेड सिस्टम और इंटरनेट ऑफ थिंग्स (IoT) उपकरणों पर तैनात करने के लिए है। TFLite को इन उपकरणों पर सीमित गणनात्मक शक्ति, मेमोरी और ऊर्जा खपत के साथ मॉडल को कुशलता से ऑप्टिमाइज़ और चलाने के लिए डिज़ाइन किया गया है। एक TensorFlow Lite मॉडल को एक विशेष कुशल पोर्टेबल प्रारूप में दर्शाया जाता है जिसे `.tflite` फ़ाइल एक्सटेंशन द्वारा पहचाना जाता है। - -🤗 Optimum में `exporters.tflite` मॉड्यूल के माध्यम से 🤗 Transformers मॉडल को TFLite में निर्यात करने की कार्यक्षमता है। समर्थित मॉडल आर्किटेक्चर की सूची के लिए, कृपया [🤗 Optimum दस्तावेज़](https://huggingface.co/docs/optimum/exporters/tflite/overview) देखें। - -TFLite में एक मॉडल निर्यात करने के लिए, आवश्यक निर्भरताएँ स्थापित करें: - -```bash -pip install optimum[exporters-tf] -``` - -सभी उपलब्ध तर्कों की जांच करने के लिए, [🤗 Optimum दस्तावेज़](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model) देखें, -या कमांड लाइन में मदद देखें: - -```bash -optimum-cli export tflite --help -``` - -यदि आप 🤗 Hub से एक मॉडल का चेकपॉइंट निर्यात करना चाहते हैं, उदाहरण के लिए, `google-bert/bert-base-uncased`, निम्नलिखित कमांड चलाएँ: - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -आपको प्रगति को दर्शाते हुए लॉग दिखाई देंगे और यह दिखाएंगे कि परिणामस्वरूप `model.tflite` कहाँ सहेजा गया है, जैसे: - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite -``` - -उपरोक्त उदाहरण 🤗 Hub से एक चेकपॉइंट निर्यात करने को दर्शाता है। जब एक स्थानीय मॉडल निर्यात करते हैं, तो पहले सुनिश्चित करें कि आपने मॉडल के वज़न और टोकनाइज़र फ़ाइलों को एक ही निर्देशिका (`local_path`) में सहेजा है। CLI का उपयोग करते समय, चेकपॉइंट नाम के बजाय `model` तर्क में `local_path` पास करें। diff --git a/docs/source/ja/perf_train_tpu_tf.md b/docs/source/ja/perf_train_tpu_tf.md deleted file mode 100644 index 3ffe88267cdd..000000000000 --- a/docs/source/ja/perf_train_tpu_tf.md +++ /dev/null @@ -1,168 +0,0 @@ - - -# Training on TPU with TensorFlow - - - -詳細な説明が不要で、単にTPUのコードサンプルを入手してトレーニングを開始したい場合は、[私たちのTPUの例のノートブックをチェックしてください!](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb) - - - -### What is a TPU? - -TPUは**Tensor Processing Unit(テンソル処理ユニット)**の略です。これらはGoogleが設計したハードウェアで、ニューラルネットワーク内のテンソル計算を大幅に高速化するために使用されます。これはGPUのようなものです。ネットワークのトレーニングと推論の両方に使用できます。一般的にはGoogleのクラウドサービスを介してアクセスされますが、Google ColabとKaggle Kernelsを通じても無料で小規模のTPUに直接アクセスできます。 - -[🤗 TransformersのすべてのTensorFlowモデルはKerasモデルです](https://huggingface.co/blog/tensorflow-philosophy)ので、この文書のほとんどの方法は一般的にKerasモデル用のTPUトレーニングに適用できます!ただし、TransformersとDatasetsのHuggingFaceエコシステム(hug-o-system?)に固有のポイントもいくつかあり、それについては適用するときにそれを示します。 - -### What kinds of TPU are available? - -新しいユーザーは、さまざまなTPUとそのアクセス方法に関する幅広い情報によく混乱します。理解するための最初の重要な違いは、**TPUノード**と**TPU VM**の違いです。 - -**TPUノード**を使用すると、事実上リモートのTPUに間接的にアクセスします。別個のVMが必要で、ネットワークとデータパイプラインを初期化し、それらをリモートノードに転送します。Google ColabでTPUを使用すると、**TPUノード**スタイルでアクセスしています。 - -TPUノードを使用すると、それに慣れていない人々にはかなり予期しない動作が発生することがあります!特に、TPUはPythonコードを実行しているマシンと物理的に異なるシステムに配置されているため、データはローカルマシンにローカルで格納されているデータパイプラインが完全に失敗します。代わりに、データはGoogle Cloud Storageに格納する必要があります。ここでデータパイプラインはリモートのTPUノードで実行されている場合でも、データにアクセスできます。 - - - -すべてのデータを`np.ndarray`または`tf.Tensor`としてメモリに収めることができる場合、ColabまたはTPUノードを使用している場合でも、データをGoogle Cloud Storageにアップロードせずに`fit()`でトレーニングできます。 - - - - - -**🤗 Hugging Face固有のヒント🤗:** TFコードの例でよく見るであろう`Dataset.to_tf_dataset()`とその高レベルのラッパーである`model.prepare_tf_dataset()`は、TPUノードで失敗します。これは、`tf.data.Dataset`を作成しているにもかかわらず、それが「純粋な」`tf.data`パイプラインではなく、`tf.numpy_function`または`Dataset.from_generator()`を使用して基盤となるHuggingFace `Dataset`からデータをストリームで読み込むことからです。このHuggingFace `Dataset`はローカルディスク上のデータをバックアップしており、リモートTPUノードが読み取ることができないためです。 - - - -TPUにアクセスする第二の方法は、**TPU VM**を介してです。TPU VMを使用する場合、TPUが接続されているマシンに直接接続します。これはGPU VMでトレーニングを行うのと同様です。TPU VMは一般的にデータパイプラインに関しては特に作業がしやすく、上記のすべての警告はTPU VMには適用されません! - -これは主観的な文書ですので、こちらの意見です:**可能な限りTPUノードの使用を避けてください。** TPU VMよりも混乱しやすく、デバッグが難しいです。将来的にはサポートされなくなる可能性もあります - Googleの最新のTPUであるTPUv4は、TPU VMとしてのみアクセスできるため、TPUノードは将来的には「レガシー」のアクセス方法になる可能性が高いです。ただし、無料でTPUにアクセスできるのはColabとKaggle Kernelsの場合があります。その場合、どうしても使用しなければならない場合の取り扱い方法を説明しようとします!詳細は[TPUの例のノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb)で詳細な説明を確認してください。 - -### What sizes of TPU are available? - -単一のTPU(v2-8/v3-8/v4-8)は8つのレプリカを実行します。TPUは数百から数千のレプリカを同時に実行できる**ポッド**に存在します。単一のTPUよりも多くのTPUを使用するが、ポッド全体ではない場合(たとえばv3-32)、TPUフリートは**ポッドスライス**として参照されます。 - -Colabを介して無料のTPUにアクセスする場合、通常は単一のv2-8 TPUが提供されます。 - - -### I keep hearing about this XLA thing. What’s XLA, and how does it relate to TPUs? - -XLAは、TensorFlowとJAXの両方で使用される最適化コンパイラです。JAXでは唯一のコンパイラであり、TensorFlowではオプションですが(しかしTPUでは必須です!)、Kerasモデルをトレーニングする際に`model.compile()`に引数`jit_compile=True`を渡すことで最も簡単に有効にできます。エラーが発生せず、パフォーマンスが良好であれば、それはTPUに移行する準備が整った良い兆候です! - -TPU上でのデバッグは一般的にCPU/GPUよりも少し難しいため、TPUで試す前にまずCPU/GPUでXLAを使用してコードを実行することをお勧めします。もちろん、長時間トレーニングする必要はありません。モデルとデータパイプラインが期待通りに動作するかを確認するための数ステップだけです。 - - - -XLAコンパイルされたコードは通常高速です。したがって、TPUで実行する予定がない場合でも、`jit_compile=True`を追加することでパフォーマンスを向上させることができます。ただし、以下のXLA互換性に関する注意事項に注意してください! - - - - - -**苦い経験から生まれたヒント:** `jit_compile=True`を使用することは、CPU/GPUコードがXLA互換であることを確認し、速度を向上させる良い方法ですが、実際にTPUでコードを実行する際には多くの問題を引き起こす可能性があります。 XLAコンパイルはTPU上で暗黙的に行われるため、実際にコードをTPUで実行する前にその行を削除することを忘れないでください! - - - -### How do I make my model XLA compatible? - -多くの場合、コードはすでにXLA互換かもしれません!ただし、XLAでは動作する通常のTensorFlowでも動作しないいくつかの要素があります。以下に、3つの主要なルールにまとめています: - - - -**🤗 HuggingFace固有のヒント🤗:** TensorFlowモデルと損失関数をXLA互換に書き直すために多くの努力を払っています。通常、モデルと損失関数はデフォルトでルール#1と#2に従っているため、`transformers`モデルを使用している場合はこれらをスキップできます。ただし、独自のモデルと損失関数を記述する場合は、これらのルールを忘れないでください! - - - -#### XLA Rule #1: Your code cannot have “data-dependent conditionals” - -これは、任意の`if`ステートメントが`tf.Tensor`内の値に依存していない必要があることを意味します。例えば、次のコードブロックはXLAでコンパイルできません! - -```python -if tf.reduce_sum(tensor) > 10: - tensor = tensor / 2.0 -``` - -これは最初は非常に制限的に思えるかもしれませんが、ほとんどのニューラルネットコードはこれを行う必要はありません。通常、この制約を回避するために`tf.cond`を使用するか(ドキュメントはこちらを参照)、条件を削除して代わりに指示変数を使用したりすることができます。次のように: - -```python -sum_over_10 = tf.cast(tf.reduce_sum(tensor) > 10, tf.float32) -tensor = tensor / (1.0 + sum_over_10) -``` - -このコードは、上記のコードとまったく同じ効果を持っていますが、条件を回避することで、XLAで問題なくコンパイルできることを確認します! - -#### XLA Rule #2: Your code cannot have “data-dependent shapes” - -これは、コード内のすべての `tf.Tensor` オブジェクトの形状が、その値に依存しないことを意味します。たとえば、`tf.unique` 関数はXLAでコンパイルできないので、このルールに違反します。なぜなら、これは入力 `Tensor` の一意の値の各インスタンスを含む `tensor` を返すためです。この出力の形状は、入力 `Tensor` の重複具合によって異なるため、XLAはそれを処理しないことになります! - -一般的に、ほとんどのニューラルネットワークコードはデフォルトでルール#2に従います。ただし、いくつかの一般的なケースでは問題が発生することがあります。非常に一般的なケースの1つは、**ラベルマスキング**を使用する場合です。ラベルを無視して損失を計算する場所を示すために、ラベルを負の値に設定する方法です。NumPyまたはPyTorchのラベルマスキングをサポートする損失関数を見ると、次のような[ブールインデックス](https://numpy.org/doc/stable/user/basics.indexing.html#boolean-array-indexing)を使用したコードがよく見られます: - - -```python -label_mask = labels >= 0 -masked_outputs = outputs[label_mask] -masked_labels = labels[label_mask] -loss = compute_loss(masked_outputs, masked_labels) -mean_loss = torch.mean(loss) -``` - -このコードはNumPyやPyTorchでは完全に機能しますが、XLAでは動作しません!なぜなら、`masked_outputs`と`masked_labels`の形状はマスクされた位置の数に依存するため、これは**データ依存の形状**になります。ただし、ルール#1と同様に、このコードを書き直して、データ依存の形状なしでまったく同じ出力を生成できることがあります。 - - -```python -label_mask = tf.cast(labels >= 0, tf.float32) -loss = compute_loss(outputs, labels) -loss = loss * label_mask # Set negative label positions to 0 -mean_loss = tf.reduce_sum(loss) / tf.reduce_sum(label_mask) -``` - - -ここでは、データ依存の形状を避けるために、各位置で損失を計算してから、平均を計算する際に分子と分母の両方でマスクされた位置をゼロ化する方法を紹介します。これにより、最初のアプローチとまったく同じ結果が得られますが、XLA互換性を維持します。注意点として、ルール#1と同じトリックを使用します - `tf.bool`を`tf.float32`に変換して指標変数として使用します。これは非常に便利なトリックですので、自分のコードをXLAに変換する必要がある場合には覚えておいてください! - -#### XLA Rule #3: XLA will need to recompile your model for every different input shape it sees - -これは重要なルールです。これはつまり、入力形状が非常に変動的な場合、XLA はモデルを何度も再コンパイルする必要があるため、大きなパフォーマンスの問題が発生する可能性があるということです。これは NLP モデルで一般的に発生し、トークナイズ後の入力テキストの長さが異なる場合があります。他のモダリティでは、静的な形状が一般的であり、このルールはほとんど問題になりません。 - -ルール#3を回避する方法は何でしょうか?鍵は「パディング」です - すべての入力を同じ長さにパディングし、次に「attention_mask」を使用することで、可変形状と同じ結果を得ることができますが、XLA の問題は発生しません。ただし、過度のパディングも深刻な遅延を引き起こす可能性があります - データセット全体で最大の長さにすべてのサンプルをパディングすると、多くの計算とメモリを無駄にする可能性があります! - -この問題には完璧な解決策はありませんが、いくつかのトリックを試すことができます。非常に便利なトリックの1つは、**バッチのサンプルを32または64トークンの倍数までパディングする**ことです。これにより、トークン数がわずかに増加するだけで、すべての入力形状が32または64の倍数である必要があるため、一意の入力形状の数が大幅に減少します。一意の入力形状が少ないと、XLA の再コンパイルが少なくなります! - - - -**🤗 HuggingFace に関する具体的なヒント🤗:** 弊社のトークナイザーとデータコレクターには、ここで役立つメソッドがあります。トークナイザーを呼び出す際に `padding="max_length"` または `padding="longest"` を使用して、パディングされたデータを出力するように設定できます。トークナイザーとデータコレクターには、一意の入力形状の数を減らすのに役立つ `pad_to_multiple_of` 引数もあります! - - - -### How do I actually train my model on TPU? - -一度トレーニングが XLA 互換性があることを確認し、(TPU Node/Colab を使用する場合は)データセットが適切に準備されている場合、TPU 上で実行することは驚くほど簡単です!コードを変更する必要があるのは、いくつかの行を追加して TPU を初期化し、モデルとデータセットが `TPUStrategy` スコープ内で作成されるようにすることだけです。これを実際に見るには、[TPU のサンプルノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb)をご覧ください! - -### Summary - -ここでは多くの情報が提供されましたので、TPU でモデルをトレーニングする際に以下のチェックリストを使用できます: - -- コードが XLA の三つのルールに従っていることを確認します。 -- CPU/GPU で `jit_compile=True` を使用してモデルをコンパイルし、XLA でトレーニングできることを確認します。 -- データセットをメモリに読み込むか、TPU 互換のデータセット読み込みアプローチを使用します([ノートブックを参照](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb))。 -- コードを Colab(アクセラレータを「TPU」に設定)または Google Cloud の TPU VM に移行します。 -- TPU 初期化コードを追加します([ノートブックを参照](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb))。 -- `TPUStrategy` を作成し、データセットの読み込みとモデルの作成が `strategy.scope()` 内で行われることを確認します([ノートブックを参照](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/tpu_training-tf.ipynb))。 -- TPU に移行する際に `jit_compile=True` を外すのを忘れないでください! -- 🙏🙏🙏🥺🥺🥺 -- `model.fit()` を呼び出します。 -- おめでとうございます! - - diff --git a/docs/source/ja/tf_xla.md b/docs/source/ja/tf_xla.md deleted file mode 100644 index 1f5a2af1a5a2..000000000000 --- a/docs/source/ja/tf_xla.md +++ /dev/null @@ -1,179 +0,0 @@ - - -# XLA Integration for TensorFlow Models - -[[open-in-colab]] - -加速線形代数(Accelerated Linear Algebra)、通称XLAは、TensorFlowモデルのランタイムを高速化するためのコンパイラです。[公式ドキュメント](https://www.tensorflow.org/xla)によれば、XLA(Accelerated Linear Algebra)は線形代数のためのドメイン固有のコンパイラで、TensorFlowモデルを潜在的にソースコードの変更なしで高速化できます。 - -TensorFlowでXLAを使用するのは簡単です。XLAは`tensorflow`ライブラリ内にパッケージ化されており、[`tf.function`](https://www.tensorflow.org/guide/intro_to_graphs)などのグラフを作成する関数内で`jit_compile`引数を使用してトリガーできます。`fit()`や`predict()`などのKerasメソッドを使用する場合、`model.compile()`に`jit_compile`引数を渡すだけでXLAを有効にできます。ただし、XLAはこれらのメソッドに限定されているわけではありません。任意の`tf.function`を高速化するためにも使用できます。 - -🤗 Transformers内のいくつかのTensorFlowメソッドは、XLAと互換性があるように書き直されています。これには、[GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)、[T5](https://huggingface.co/docs/transformers/model_doc/t5)、[OPT](https://huggingface.co/docs/transformers/model_doc/opt)などのテキスト生成モデルや、[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)などの音声処理モデルも含まれます。 - -速度向上の具体的な量はモデルに非常に依存しますが、🤗 Transformers内のTensorFlowテキスト生成モデルでは、約100倍の速度向上を確認しています。このドキュメントでは、これらのモデルにXLAを使用して最大のパフォーマンスを得る方法を説明します。また、ベンチマークとXLA統合のデザイン哲学について詳しく学びたい場合の追加リソースへのリンクも提供します。 - -## Running TF functions with XLA - -以下のTensorFlowモデルを考えてみましょう: - - -```py -import tensorflow as tf - -model = tf.keras.Sequential( - [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")] -) -``` - -上記のモデルは、次元が`(10, )`の入力を受け入れます。このモデルをフォワードパスで実行するには、次のようにします: - - -```py -# Generate random inputs for the model. -batch_size = 16 -input_vector_dim = 10 -random_inputs = tf.random.normal((batch_size, input_vector_dim)) - -# Run a forward pass. -_ = model(random_inputs) -``` - -XLAでコンパイルされた関数を使用してフォワードパスを実行するには、以下のようにします: - - -```py -xla_fn = tf.function(model, jit_compile=True) -_ = xla_fn(random_inputs) -``` - -`model`のデフォルトの `call()` 関数はXLAグラフをコンパイルするために使用されます。ただし、XLAにコンパイルしたい他のモデル関数がある場合、それも可能です。以下はその方法です: - - -```py -my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True) -``` - -## Running a TF text generation model with XLA from 🤗 Transformers - -🤗 Transformers内でXLAでの高速化された生成を有効にするには、最新バージョンの`transformers`がインストールされている必要があります。次のコマンドを実行してインストールできます: - -```bash -pip install transformers --upgrade -``` - -次に、次のコードを実行できます: - - -```py -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -# Will error if the minimal version of Transformers is not installed. -from transformers.utils import check_min_version - -check_min_version("4.21.0") - - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") -input_string = ["TensorFlow is"] - -# One line to create an XLA generation function -xla_generate = tf.function(model.generate, jit_compile=True) - -tokenized_input = tokenizer(input_string, return_tensors="tf") -generated_tokens = xla_generate(**tokenized_input, num_beams=2) - -decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) -print(f"Generated -- {decoded_text}") -# Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the -``` - -`generate()`でXLAを有効にするのは、たった一行のコードです。コードの残り部分は変更されていません。ただし、XLA固有のいくつかの注意点が上記のコードスニペットにあります。これらに注意する必要があり、XLAがもたらす速度向上を実現するためにそれらを把握することが重要です。次のセクションでこれらについて詳しく説明します。 - - -## Gotchas to be aware of - -XLAを有効にした関数(上記の`xla_generate()`など)を初めて実行すると、内部で計算グラフを推論しようとしますが、これは時間がかかります。このプロセスは["トレーシング"(tracing)](https://www.tensorflow.org/guide/intro_to_graphs#when_is_a_function_tracing)として知られています。 - -生成時間が高速ではないことに気付くかもしれません。`xla_generate()`(または他のXLA対応関数)の連続呼び出しでは、関数への入力が最初に計算グラフが構築されたときと同じ形状に従っている場合、計算グラフを推論する必要はありません。これは、入力形状が固定されているモダリティ(例:画像)には問題ありませんが、変数の入力形状モダリティ(例:テキスト)を扱う場合には注意が必要です。 - -`xla_generate()`が常に同じ入力形状で動作するようにするには、トークナイザを呼び出す際に`padding`引数を指定できます。 - -```py -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") -input_string = ["TensorFlow is"] - -xla_generate = tf.function(model.generate, jit_compile=True) - -# Here, we call the tokenizer with padding options. -tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf") - -generated_tokens = xla_generate(**tokenized_input, num_beams=2) -decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) -print(f"Generated -- {decoded_text}") -``` - -これにより、`xla_generate()`への入力が常にトレースされた形状の入力を受け取ることを確認し、生成時間の高速化を実現できます。以下のコードでこれを確認できます: - -```py -import time -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") - -xla_generate = tf.function(model.generate, jit_compile=True) - -for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]: - tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf") - start = time.time_ns() - generated_tokens = xla_generate(**tokenized_input, num_beams=2) - end = time.time_ns() - print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n") -``` - -Tesla T4 GPUを使用すると、次のような出力が期待されます: - -```bash -Execution time -- 30819.6 ms - -Execution time -- 79.0 ms - -Execution time -- 78.9 ms -``` - -最初の`xla_generate()`呼び出しはトレーシングのために時間がかかりますが、連続する呼び出しは桁違いに高速です。生成オプションのいかなる変更も、再トレーシングを引き起こし、生成時間の遅延を引き起こすことに注意してください。 - -このドキュメントでは、🤗 Transformersが提供するテキスト生成オプションをすべて網羅していません。高度なユースケースについてはドキュメンテーションを参照することをお勧めします。 - -## Additional Resources - -ここでは、🤗 Transformersと一般的なXLAについてさらに詳しく学びたい場合のいくつかの追加リソースを提供します。 - -* [このColab Notebook](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/91_tf_xla_generate.ipynb)では、XLA対応のエンコーダーデコーダー([T5](https://huggingface.co/docs/transformers/model_doc/t5)など)およびデコーダー専用([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)など)テキスト生成モデルを試すための対話型デモが提供されています。 -* [このブログ記事](https://huggingface.co/blog/tf-xla-generate)では、XLA対応モデルの比較ベンチマークの概要と、TensorFlowでのXLAについての友好的な紹介が提供されています。 -* [このブログ記事](https://blog.tensorflow.org/2022/11/how-hugging-face-improved-text-generation-performance-with-xla.html)では、🤗 TransformersのTensorFlowモデルにXLAサポートを追加する際の設計哲学について説明しています。 -* 一般的なXLAとTensorFlowグラフについて詳しく学ぶためのおすすめの投稿: - * [XLA: 機械学習用の最適化コンパイラ](https://www.tensorflow.org/xla) - * [グラフと`tf.function`の紹介](https://www.tensorflow.org/guide/intro_to_graphs) - * [`tf.function`を使用したパフォーマンス向上](https://www.tensorflow.org/guide/function) diff --git a/docs/source/ja/tflite.md b/docs/source/ja/tflite.md deleted file mode 100644 index ad3e9a3f484e..000000000000 --- a/docs/source/ja/tflite.md +++ /dev/null @@ -1,58 +0,0 @@ - - -# Export to TFLite - -[TensorFlow Lite](https://www.tensorflow.org/lite/guide)は、モバイルフォン、組み込みシステム、およびモノのインターネット(IoT)デバイスなど、リソースに制約のあるデバイスに機械学習モデルを展開するための軽量なフレームワークです。TFLiteは、計算能力、メモリ、および電力消費が限られているこれらのデバイス上でモデルを効率的に最適化して実行するために設計されています。 -TensorFlow Liteモデルは、`.tflite`ファイル拡張子で識別される特別な効率的なポータブル形式で表されます。 - -🤗 Optimumは、🤗 TransformersモデルをTFLiteにエクスポートするための機能を`exporters.tflite`モジュールを介して提供しています。サポートされているモデルアーキテクチャのリストについては、[🤗 Optimumのドキュメント](https://huggingface.co/docs/optimum/exporters/tflite/overview)をご参照ください。 - -モデルをTFLiteにエクスポートするには、必要な依存関係をインストールしてください: - - -```bash -pip install optimum[exporters-tf] -``` - -すべての利用可能な引数を確認するには、[🤗 Optimumドキュメント](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model)を参照するか、コマンドラインでヘルプを表示してください: - -```bash -optimum-cli export tflite --help -``` - -🤗 Hubからモデルのチェックポイントをエクスポートするには、例えば `google-bert/bert-base-uncased` を使用する場合、次のコマンドを実行します: - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -進行状況を示すログが表示され、生成された `model.tflite` が保存された場所も表示されるはずです: - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite - ``` - -上記の例は🤗 Hubからチェックポイントをエクスポートする方法を示しています。ローカルモデルをエクスポートする場合、まずモデルの重みファイルとトークナイザファイルを同じディレクトリ(`local_path`)に保存したことを確認してください。CLIを使用する場合、🤗 Hubのチェックポイント名の代わりに`model`引数に`local_path`を渡します。 - - diff --git a/docs/source/ko/tflite.md b/docs/source/ko/tflite.md deleted file mode 100644 index 464106a6b7c2..000000000000 --- a/docs/source/ko/tflite.md +++ /dev/null @@ -1,62 +0,0 @@ - - -# TFLite로 내보내기[[export-to-tflite]] - -[TensorFlow Lite](https://www.tensorflow.org/lite/guide)는 자원이 제한된 휴대폰, 임베디드 시스템, 사물인터넷(IoT) 기기에서 -기계학습 모델을 배포하기 위한 경량 프레임워크입니다. -TFLite는 연산 능력, 메모리, 전력 소비가 제한된 기기에서 모델을 효율적으로 최적화하고 실행하기 위해 -설계되었습니다. -TensorFlow Lite 모델은 `.tflite` 파일 확장자로 식별되는 특수하고 효율적인 휴대용 포맷으로 표현됩니다. - -🤗 Optimum은 `exporters.tflite` 모듈로 🤗 Transformers 모델을 TFLite로 내보내는 기능을 제공합니다. -지원되는 모델 아키텍처 목록은 [🤗 Optimum 문서](https://huggingface.co/docs/optimum/exporters/tflite/overview)를 참고하세요. - -모델을 TFLite로 내보내려면, 필요한 종속성을 설치하세요: - -```bash -pip install optimum[exporters-tf] -``` - -모든 사용 가능한 인수를 확인하려면, [🤗 Optimum 문서](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model)를 참고하거나 -터미널에서 도움말을 살펴보세요: - -```bash -optimum-cli export tflite --help -``` - -예를 들어 🤗 Hub에서의 `google-bert/bert-base-uncased` 모델 체크포인트를 내보내려면, 다음 명령을 실행하세요: - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -다음과 같이 진행 상황을 나타내는 로그와 결과물인 `model.tflite`가 저장된 위치를 보여주는 로그가 표시됩니다: - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite - ``` - -위 예제는 🤗 Hub에서의 체크포인트를 내보내는 방법을 보여줍니다. -로컬 모델을 내보낸다면, 먼저 모델 가중치와 토크나이저 파일이 모두 같은 디렉터리( `local_path` )에 저장됐는지 확인하세요. -CLI를 사용할 때, 🤗 Hub에서의 체크포인트 이름 대신 `model` 인수에 `local_path`를 전달하면 됩니다. \ No newline at end of file diff --git a/docs/source/zh/tf_xla.md b/docs/source/zh/tf_xla.md deleted file mode 100644 index 2e5b444d876c..000000000000 --- a/docs/source/zh/tf_xla.md +++ /dev/null @@ -1,179 +0,0 @@ - - -# 用于 TensorFlow 模型的 XLA 集成 - -[[open-in-colab]] - -加速线性代数,也称为XLA,是一个用于加速TensorFlow模型运行时间的编译器。从[官方文档](https://www.tensorflow.org/xla)中可以看到: - -XLA(加速线性代数)是一种针对线性代数的特定领域编译器,可以在可能不需要更改源代码的情况下加速TensorFlow模型。 - -在TensorFlow中使用XLA非常简单——它包含在`tensorflow`库中,并且可以使用任何图创建函数中的`jit_compile`参数来触发,例如[`tf.function`](https://www.tensorflow.org/guide/intro_to_graphs)。在使用Keras方法如`fit()`和`predict()`时,只需将`jit_compile`参数传递给`model.compile()`即可启用XLA。然而,XLA不仅限于这些方法 - 它还可以用于加速任何任意的`tf.function`。 - -在🤗 Transformers中,几个TensorFlow方法已经被重写为与XLA兼容,包括[GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)、[T5](https://huggingface.co/docs/transformers/model_doc/t5)和[OPT](https://huggingface.co/docs/transformers/model_doc/opt)等文本生成模型,以及[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)等语音处理模型。 - -虽然确切的加速倍数很大程度上取决于模型,但对于🤗 Transformers中的TensorFlow文本生成模型,我们注意到速度提高了约100倍。本文档将解释如何在这些模型上使用XLA获得最大的性能。如果您有兴趣了解更多关于基准测试和我们在XLA集成背后的设计哲学的信息,我们还将提供额外的资源链接。 - - -## 使用 XLA 运行 TensorFlow 函数 - -让我们考虑以下TensorFlow 中的模型: - -```py -import tensorflow as tf - -model = tf.keras.Sequential( - [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")] -) -``` - -上述模型接受维度为 `(10,)` 的输入。我们可以像下面这样使用模型进行前向传播: - -```py -# Generate random inputs for the model. -batch_size = 16 -input_vector_dim = 10 -random_inputs = tf.random.normal((batch_size, input_vector_dim)) - -# Run a forward pass. -_ = model(random_inputs) -``` - -为了使用 XLA 编译的函数运行前向传播,我们需要执行以下操作: - -```py -xla_fn = tf.function(model, jit_compile=True) -_ = xla_fn(random_inputs) -``` - -`model`的默认`call()`函数用于编译XLA图。但如果你想将其他模型函数编译成XLA,也是可以的,如下所示: - -```py -my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True) -``` - -## 在🤗 Transformers库中使用XLA运行TensorFlow文本生成模型 - -要在🤗 Transformers中启用XLA加速生成,您需要安装最新版本的`transformers`。您可以通过运行以下命令来安装它: - -```bash -pip install transformers --upgrade -``` - -然后您可以运行以下代码: - -```py -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -# Will error if the minimal version of Transformers is not installed. -from transformers.utils import check_min_version - -check_min_version("4.21.0") - - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") -input_string = ["TensorFlow is"] - -# One line to create an XLA generation function -xla_generate = tf.function(model.generate, jit_compile=True) - -tokenized_input = tokenizer(input_string, return_tensors="tf") -generated_tokens = xla_generate(**tokenized_input, num_beams=2) - -decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) -print(f"Generated -- {decoded_text}") -# Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the -``` - -正如您所注意到的,在`generate()`上启用XLA只需要一行代码。其余部分代码保持不变。然而,上面的代码片段中有一些与XLA相关的注意事项。您需要了解这些注意事项,以充分利用XLA可能带来的性能提升。我们将在下面的部分讨论这些内容。 - -## 需要关注的注意事项 - -当您首次执行启用XLA的函数(如上面的`xla_generate()`)时,它将在内部尝试推断计算图,这是一个耗时的过程。这个过程被称为[“tracing”](https://www.tensorflow.org/guide/intro_to_graphs#when_is_a_function_tracing)。 - -您可能会注意到生成时间并不快。连续调用`xla_generate()`(或任何其他启用了XLA的函数)不需要再次推断计算图,只要函数的输入与最初构建计算图时的形状相匹配。对于具有固定输入形状的模态(例如图像),这不是问题,但如果您正在处理具有可变输入形状的模态(例如文本),则必须注意。 - -为了确保`xla_generate()`始终使用相同的输入形状,您可以在调用`tokenizer`时指定`padding`参数。 - -```py -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") -input_string = ["TensorFlow is"] - -xla_generate = tf.function(model.generate, jit_compile=True) - -# Here, we call the tokenizer with padding options. -tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf") - -generated_tokens = xla_generate(**tokenized_input, num_beams=2) -decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) -print(f"Generated -- {decoded_text}") -``` - -通过这种方式,您可以确保`xla_generate()`的输入始终具有它跟踪的形状,从而加速生成时间。您可以使用以下代码来验证这一点: - -```py -import time -import tensorflow as tf -from transformers import AutoTokenizer, TFAutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="") -model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") - -xla_generate = tf.function(model.generate, jit_compile=True) - -for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]: - tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf") - start = time.time_ns() - generated_tokens = xla_generate(**tokenized_input, num_beams=2) - end = time.time_ns() - print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n") -``` - -在Tesla T4 GPU上,您可以期望如下的输出: - -```bash -Execution time -- 30819.6 ms - -Execution time -- 79.0 ms - -Execution time -- 78.9 ms -``` - -第一次调用`xla_generate()`会因为`tracing`而耗时,但后续的调用会快得多。请注意,任何时候对生成选项的更改都会触发重新`tracing`,从而导致生成时间减慢。 - -在本文档中,我们没有涵盖🤗 Transformers提供的所有文本生成选项。我们鼓励您阅读文档以了解高级用例。 - -## 附加资源 - -以下是一些附加资源,如果您想深入了解在🤗 Transformers和其他库下使用XLA: - -* [这个Colab Notebook](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/91_tf_xla_generate.ipynb) 提供了一个互动演示,让您可以尝试使用XLA兼容的编码器-解码器(例如[T5](https://huggingface.co/docs/transformers/model_doc/t5))和仅解码器(例如[GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2))文本生成模型。 - -* [这篇博客文章](https://huggingface.co/blog/tf-xla-generate) 提供了XLA兼容模型的比较基准概述,以及关于在TensorFlow中使用XLA的友好介绍。 - -* [这篇博客文章](https://blog.tensorflow.org/2022/11/how-hugging-face-improved-text-generation-performance-with-xla.html) 讨论了我们在🤗 Transformers中为TensorFlow模型添加XLA支持的设计理念。 - -* 推荐用于更多学习XLA和TensorFlow图的资源: - * [XLA:面向机器学习的优化编译器](https://www.tensorflow.org/xla) - * [图和tf.function简介](https://www.tensorflow.org/guide/intro_to_graphs) - * [使用tf.function获得更好的性能](https://www.tensorflow.org/guide/function) \ No newline at end of file diff --git a/docs/source/zh/tflite.md b/docs/source/zh/tflite.md deleted file mode 100644 index f0280156def4..000000000000 --- a/docs/source/zh/tflite.md +++ /dev/null @@ -1,54 +0,0 @@ - - -# 导出为 TFLite - -[TensorFlow Lite](https://www.tensorflow.org/lite/guide) 是一个轻量级框架,用于资源受限的设备上,如手机、嵌入式系统和物联网(IoT)设备,部署机器学习模型。TFLite 旨在在计算能力、内存和功耗有限的设备上优化和高效运行模型。模型以一种特殊的高效可移植格式表示,其文件扩展名为 `.tflite`。 - -🤗 Optimum 通过 `exporters.tflite` 模块提供将 🤗 Transformers 模型导出至 TFLite 格式的功能。请参考 [🤗 Optimum 文档](https://huggingface.co/docs/optimum/exporters/tflite/overview) 以获取支持的模型架构列表。 - -要将模型导出为 TFLite 格式,请安装所需的依赖项: - -```bash -pip install optimum[exporters-tf] -``` - -请参阅 [🤗 Optimum 文档](https://huggingface.co/docs/optimum/main/en/exporters/tflite/usage_guides/export_a_model) 以查看所有可用参数,或者在命令行中查看帮助: - -```bash -optimum-cli export tflite --help -``` - -运行以下命令,以从 🤗 Hub 导出模型的检查点(checkpoint),以 `google-bert/bert-base-uncased` 为例: - -```bash -optimum-cli export tflite --model google-bert/bert-base-uncased --sequence_length 128 bert_tflite/ -``` - -你应该能在日志中看到导出进度以及生成的 `model.tflite` 文件的保存位置,如下所示: - -```bash -Validating TFLite model... - -[✓] TFLite model output names match reference model (logits) - - Validating TFLite Model output "logits": - -[✓] (1, 128, 30522) matches (1, 128, 30522) - -[x] values not close enough, max diff: 5.817413330078125e-05 (atol: 1e-05) -The TensorFlow Lite export succeeded with the warning: The maximum absolute difference between the output of the reference model and the TFLite exported model is not within the set tolerance 1e-05: -- logits: max diff = 5.817413330078125e-05. - The exported model was saved at: bert_tflite -``` - -上面的示例说明了从 🤗 Hub 导出检查点的过程。导出本地模型时,首先需要确保将模型的权重和分词器文件保存在同一目录(`local_path`)中。在使用 CLI(命令行)时,将 `local_path` 传递给 `model` 参数,而不是 🤗 Hub 上的检查点名称。 \ No newline at end of file diff --git a/src/transformers/activations_tf.py b/src/transformers/activations_tf.py deleted file mode 100644 index 8dccf6c4f46b..000000000000 --- a/src/transformers/activations_tf.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import tensorflow as tf -from packaging.version import parse - - -try: - import tf_keras as keras -except (ModuleNotFoundError, ImportError): - import keras - - if parse(keras.__version__).major > 2: - raise ValueError( - "Your currently installed version of Keras is Keras 3, but this is not yet supported in " - "Transformers. Please install the backwards-compatible tf-keras package with " - "`pip install tf-keras`." - ) - - -def _gelu(x): - """ - Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when - initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see - https://huggingface.co/papers/1606.08415 - """ - x = tf.convert_to_tensor(x) - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) - - return x * cdf - - -def _gelu_new(x): - """ - Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://huggingface.co/papers/1606.0841 - - Args: - x: float Tensor to perform activation - - Returns: - `x` with the GELU activation applied. - """ - x = tf.convert_to_tensor(x) - pi = tf.cast(math.pi, x.dtype) - coeff = tf.cast(0.044715, x.dtype) - cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) - - return x * cdf - - -def mish(x): - x = tf.convert_to_tensor(x) - - return x * tf.tanh(tf.math.softplus(x)) - - -def gelu_fast(x): - x = tf.convert_to_tensor(x) - coeff1 = tf.cast(0.044715, x.dtype) - coeff2 = tf.cast(0.7978845608, x.dtype) - - return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) - - -def quick_gelu(x): - x = tf.convert_to_tensor(x) - coeff = tf.cast(1.702, x.dtype) - return x * tf.math.sigmoid(coeff * x) - - -def gelu_10(x): - """ - Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as - it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to - https://huggingface.co/papers/2004.09602 - - Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when - initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see - https://huggingface.co/papers/1606.08415 :param x: :return: - """ - return tf.clip_by_value(_gelu(x), -10, 10) - - -def glu(x, axis=-1): - """ - Gated Linear Unit. Implementation as defined in the original paper (see https://huggingface.co/papers/1612.08083), where - the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B). - - Args: - `x`: float Tensor to perform activation - `axis`: dimension across which `x` be split in half - - Returns: - `x` with the GLU activation applied (with its size halved across the dimension `axis`). - """ - a, b = tf.split(x, 2, axis=axis) - return a * tf.math.sigmoid(b) - - -if parse(tf.version.VERSION) >= parse("2.4"): - - def approximate_gelu_wrap(x): - return keras.activations.gelu(x, approximate=True) - - gelu = keras.activations.gelu - gelu_new = approximate_gelu_wrap -else: - gelu = _gelu - gelu_new = _gelu_new - - -ACT2FN = { - "gelu": gelu, - "gelu_10": gelu_10, - "gelu_fast": gelu_fast, - "gelu_new": gelu_new, - "glu": glu, - "mish": mish, - "quick_gelu": quick_gelu, - "relu": keras.activations.relu, - "sigmoid": keras.activations.sigmoid, - "silu": keras.activations.swish, - "swish": keras.activations.swish, - "tanh": keras.activations.tanh, -} - - -def get_tf_activation(activation_string): - if activation_string in ACT2FN: - return ACT2FN[activation_string] - else: - raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") diff --git a/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py deleted file mode 100755 index e2c825a45b60..000000000000 --- a/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert Seq2Seq TF Hub checkpoint.""" - -import argparse - -from . import ( - BertConfig, - BertGenerationConfig, - BertGenerationDecoder, - BertGenerationEncoder, - load_tf_weights_in_bert_generation, - logging, -) - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder): - # Initialise PyTorch model - bert_config = BertConfig.from_pretrained( - "google-bert/bert-large-cased", - vocab_size=vocab_size, - max_position_embeddings=512, - is_decoder=True, - add_cross_attention=True, - ) - bert_config_dict = bert_config.to_dict() - del bert_config_dict["type_vocab_size"] - config = BertGenerationConfig(**bert_config_dict) - if is_encoder: - model = BertGenerationEncoder(config) - else: - model = BertGenerationDecoder(config) - print(f"Building PyTorch model from configuration: {config}") - - # Load weights from tf checkpoint - load_tf_weights_in_bert_generation( - model, - tf_hub_path, - model_class="bert", - is_encoder_named_decoder=is_encoder_named_decoder, - is_encoder=is_encoder, - ) - - # Save pytorch-model - print(f"Save PyTorch model and config to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--is_encoder_named_decoder", - action="store_true", - help="If decoder has to be renamed to encoder in PyTorch model.", - ) - parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.") - parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model") - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_hub_path, - args.pytorch_dump_path, - args.is_encoder_named_decoder, - args.vocab_size, - is_encoder=args.is_encoder, - ) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py deleted file mode 100644 index 08fa411dc6f5..000000000000 --- a/src/transformers/generation/flax_logits_process.py +++ /dev/null @@ -1,544 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect - -import jax -import jax.lax as lax -import jax.numpy as jnp -from jax.experimental import sparse - -from ..utils import add_start_docstrings -from ..utils.logging import get_logger - - -logger = get_logger(__name__) - - -LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`): - Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam - search or log softmax for each vocabulary token when using beam search - kwargs (`dict[str, Any]`, *optional*): - Additional logits processor specific kwargs. - - Return: - `jnp.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. - -""" - - -class FlaxLogitsProcessor: - """Abstract base class for all logit processors that can be applied during generation.""" - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray: - """Flax method for processing logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class FlaxLogitsWarper: - """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray: - """Flax method for warping logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class FlaxLogitsProcessorList(list): - """ - This class can be used to create a list of [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to subsequently process - a `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each - [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to the inputs. - """ - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray: - for processor in self: - function_args = inspect.signature(processor.__call__).parameters - if len(function_args) > 3: - if not all(arg in kwargs for arg in list(function_args.keys())[2:]): - raise ValueError( - f"Make sure that all the required parameters: {list(function_args.keys())} for " - f"{processor.__class__} are passed to the logits processor." - ) - scores = processor(input_ids, scores, cur_len, **kwargs) - else: - scores = processor(input_ids, scores, cur_len) - return scores - - -class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): - r""" - [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution). - - Args: - temperature (`float`): - The value used to module the logits distribution. - """ - - def __init__(self, temperature: float): - if not isinstance(temperature, float) or not (temperature > 0): - raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") - - self.temperature = temperature - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - scores = scores / self.temperature - return scores - - -class FlaxTopPLogitsWarper(FlaxLogitsWarper): - """ - [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. - - Args: - top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. - """ - - def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): - raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") - if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): - raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}") - - self.top_p = top_p - self.filter_value = filter_value - self.min_tokens_to_keep = min_tokens_to_keep - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1]) - - mask_scores = jnp.full_like(scores, self.filter_value) - cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1) - score_mask = cumulative_probs < self.top_p - - # include the token that is higher than top_p as well - score_mask = jnp.roll(score_mask, 1) - score_mask |= score_mask.at[:, 0].set(True) - - # min tokens to keep - score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True) - - topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores) - next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1] - - return next_scores - - -class FlaxTopKLogitsWarper(FlaxLogitsWarper): - r""" - [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. - - Args: - top_k (`int`): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. - """ - - def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - if not isinstance(top_k, int) or top_k <= 0: - raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") - - self.top_k = max(top_k, min_tokens_to_keep) - self.filter_value = filter_value - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - batch_size, vocab_size = scores.shape - next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value) - - topk = min(self.top_k, scores.shape[-1]) # Safety check - topk_scores, topk_indices = lax.top_k(scores, topk) - shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten() - topk_scores_flat = topk_scores.flatten() - topk_indices_flat = topk_indices.flatten() + shift - - next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat) - next_scores = next_scores_flat.reshape(batch_size, vocab_size) - return next_scores - - -class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] that enforces the specified token as the first generated token. - - Args: - bos_token_id (`int`): - The id of the token to force as the first generated token. - """ - - def __init__(self, bos_token_id: int): - self.bos_token_id = bos_token_id - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - new_scores = jnp.full(scores.shape, -float("inf")) - - apply_penalty = 1 - jnp.bool_(cur_len - 1) - - scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores) - - return scores - - -class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. - - Args: - max_length (`int`): - The maximum length of the sequence to be generated. - eos_token_id (`int`): - The id of the token to force as the last generated token when `max_length` is reached. - """ - - def __init__(self, max_length: int, eos_token_id: int): - self.max_length = max_length - self.eos_token_id = eos_token_id - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - new_scores = jnp.full(scores.shape, -float("inf")) - - apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) - - scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores) - - return scores - - -class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0. - - Args: - min_length (`int`): - The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`int`): - The id of the *end-of-sequence* token. - """ - - def __init__(self, min_length: int, eos_token_id: int): - if not isinstance(min_length, int) or min_length < 0: - raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") - - self.min_length = min_length - self.eos_token_id = eos_token_id - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - # create boolean flag to decide if min length penalty should be applied - apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) - - scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores) - - return scores - - -class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using - `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the - beginning of the generation. - - Args: - begin_suppress_tokens (`list[int]`): - Tokens to not sample. - begin_index (`int`): - Index where the tokens are suppressed. - """ - - def __init__(self, begin_suppress_tokens, begin_index): - self.begin_suppress_tokens = list(begin_suppress_tokens) - self.begin_index = begin_index - - def __call__(self, input_ids, scores, cur_len: int): - apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index) - - scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores) - - return scores - - -class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs - to be `-inf` so they are not sampled. - - Args: - suppress_tokens (`list`): - Tokens to not sample. - """ - - def __init__(self, suppress_tokens: list): - self.suppress_tokens = list(suppress_tokens) - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - scores = scores.at[..., self.suppress_tokens].set(-float("inf")) - - return scores - - -class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to - token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens - to `-inf` so that they are sampled at their corresponding index. - - Args: - force_token_map (`list`): - Map giving token ids and indices where they will be forced to be sampled. - """ - - def __init__(self, force_token_map): - force_token_map = dict(force_token_map) - # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the - # index of the array corresponds to the index of the token to be forced, for XLA compatibility. - # Indexes without forced tokens will have a negative value. - force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1 - for index, token in force_token_map.items(): - if token is not None: - force_token_array = force_token_array.at[index].set(token) - self.force_token_array = jnp.int32(force_token_array) - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - def _force_token(generation_idx): - batch_size = scores.shape[0] - current_token = self.force_token_array[generation_idx] - - new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf") - updates = jnp.zeros((batch_size, 1), dtype=scores.dtype) - new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token)) - return new_scores - - scores = lax.cond( - cur_len >= self.force_token_array.shape[0], - # If the current length is geq than the length of force_token_array, the processor does nothing. - lambda: scores, - # Otherwise, it may force a certain token. - lambda: lax.cond( - self.force_token_array[cur_len] >= 0, - # Only valid (positive) tokens are forced - lambda: _force_token(cur_len), - # Otherwise, the processor does nothing. - lambda: scores, - ), - ) - return scores - - -class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor): - r""" - Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log - probs to `inf` so that they are sampled at their corresponding index. - - Args: - generate_config (`GenerateConfig`): - The generate config used to generate the output. The following parameters are required: - eos_token_id (`int`, *optional*, defaults to 50257): - The id of the *end-of-sequence* token. - no_timestamps_token_id (`int`, *optional*, defaults to 50363): - The id of the `"<|notimestamps|>"` token. - max_initial_timestamp_index (`int`, *optional*, defaults to 1): - Used to set the maximum value of the initial timestamp. This is used to prevent the model from - predicting timestamps that are too far in the future. - """ - - def __init__(self, generate_config, model_config, decoder_input_length): - self.eos_token_id = generate_config.eos_token_id - self.no_timestamps_token_id = generate_config.no_timestamps_token_id - self.timestamp_begin = generate_config.no_timestamps_token_id + 1 - - self.begin_index = decoder_input_length + 1 - - if generate_config.is_multilingual: - # room for language token and task token - self.begin_index += 2 - if hasattr(generate_config, "max_initial_timestamp_index"): - self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index - else: - self.max_initial_timestamp_index = model_config.vocab_size - if self.max_initial_timestamp_index is None: - self.max_initial_timestamp_index = model_config.vocab_size - - def __call__(self, input_ids, scores, cur_len): - # suppress <|notimestamps|> which is handled by without_timestamps - scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf")) - - def handle_pairs(input_ids_k, scores_k): - last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False) - last_was_timestamp = jnp.where( - input_ids_k[cur_len - 1] >= self.timestamp_begin, - True and last_was_timestamp, - False, - ) - - penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False) - penultimate_was_timestamp = jnp.where( - input_ids_k[cur_len - 2] >= self.timestamp_begin, - True, - penultimate_was_timestamp, - ) - - return jnp.where( - last_was_timestamp, - jnp.where( - penultimate_was_timestamp > 0, - scores_k.at[self.timestamp_begin :].set(-float("inf")), - scores_k.at[: self.eos_token_id].set(-float("inf")), - ), - scores_k, - ) - - scores = jax.vmap(handle_pairs)(input_ids, scores) - - apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False) - apply_max_initial_timestamp = jnp.where( - self.max_initial_timestamp_index is not None, - True and apply_max_initial_timestamp, - False, - ) - - last_allowed = self.timestamp_begin + self.max_initial_timestamp_index - - scores = jnp.where( - apply_max_initial_timestamp, - scores.at[:, last_allowed + 1 :].set(-float("inf")), - scores, - ) - - # if sum of probability over timestamps is above any other token, sample timestamp - logprobs = jax.nn.log_softmax(scores, axis=-1) - - def handle_cumulative_probs(logprobs_k, scores_k): - timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1) - max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin]) - return jnp.where( - timestamp_logprob > max_text_token_logprob, - scores_k.at[: self.timestamp_begin].set(-float("inf")), - scores_k, - ) - - scores = jax.vmap(handle_cumulative_probs)(logprobs, scores) - - return scores - - -class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See - [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). - - Args: - ngram_size (`int`): - All ngrams of size `ngram_size` can only occur once. - """ - - def __init__(self, ngram_size: int): - if not isinstance(ngram_size, int) or ngram_size <= 0: - raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") - self.ngram_size = ngram_size - - def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int): - """ - get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that - represent the n-grams that occurred previously. - The BCOO representation allow to store only the few non-zero entries, instead of the full (huge) matrix - """ - batch_size, seq_len = input_ids.shape - # number of n-grams in the whole sequence - seq_ngrams = seq_len - (self.ngram_size - 1) - # number of n-grams in the currently generated sequence - cur_ngrams = cur_len - (self.ngram_size - 1) - - def body_fun(i, val): - b = i % batch_size - pos = i // batch_size - return val.at[i].set( - jnp.array( - [ - b, - ] - + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)] - ) - ) - - shape = (batch_size * seq_ngrams, self.ngram_size + 1) - all_update_indices = jax.lax.fori_loop( - 0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype) - ) - - # ignore the n-grams not yet generated - data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32") - - return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size) - - def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray: - """ - Determines which tokens must be banned given latest tokens and the previously seen - ngrams. - """ - - @sparse.sparsify - @jax.vmap - def inner_fn(latest_tokens, previous_ngrams): - return previous_ngrams[tuple(latest_tokens)] - - return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams)) - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - def true_fn(): - _, vocab_size = scores.shape - # store the previously seen n-grams - previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len) - - # get the n-1 last tokens that prefix the n-gram being generated - latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype) - latest_tokens = jax.lax.dynamic_update_slice( - latest_tokens, - jax.lax.dynamic_slice( - input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1)) - ), - (0, 0), - ) - - # compute the banned tokens, ie all the tokens that when added to the latest tokens lead to a n-gram that was previously generated - banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool") - return jnp.where(banned_tokens_indices_mask, -float("inf"), scores) - - output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores) - return output diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py deleted file mode 100644 index e858a9813cea..000000000000 --- a/src/transformers/generation/flax_utils.py +++ /dev/null @@ -1,1032 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import copy -import inspect -import warnings -from functools import partial -from typing import Any, Optional, Union - -import flax -import jax -import jax.numpy as jnp -import numpy as np -from jax import lax - -from ..models.auto import ( - FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, - FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, -) -from ..utils import ModelOutput, logging -from .configuration_utils import GenerationConfig -from .flax_logits_process import ( - FlaxForcedBOSTokenLogitsProcessor, - FlaxForcedEOSTokenLogitsProcessor, - FlaxForceTokensLogitsProcessor, - FlaxLogitsProcessorList, - FlaxMinLengthLogitsProcessor, - FlaxNoRepeatNGramLogitsProcessor, - FlaxSuppressTokensAtBeginLogitsProcessor, - FlaxSuppressTokensLogitsProcessor, - FlaxTemperatureLogitsWarper, - FlaxTopKLogitsWarper, - FlaxTopPLogitsWarper, -) - - -logger = logging.get_logger(__name__) - - -@flax.struct.dataclass -class FlaxGreedySearchOutput(ModelOutput): - """ - Flax Base class for outputs of decoder-only generation models using greedy search. - - - Args: - sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): - The generated sequences. - """ - - sequences: Optional[jnp.ndarray] = None - - -@flax.struct.dataclass -class FlaxSampleOutput(ModelOutput): - """ - Flax Base class for outputs of decoder-only generation models using sampling. - - - Args: - sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): - The generated sequences. - """ - - sequences: Optional[jnp.ndarray] = None - - -@flax.struct.dataclass -class FlaxBeamSearchOutput(ModelOutput): - """ - Flax Base class for outputs of decoder-only generation models using greedy search. - - - Args: - sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): - The generated sequences. - scores (`jnp.ndarray` of shape `(batch_size,)`): - The scores (log probabilities) of the generated sequences. - """ - - sequences: Optional[jnp.ndarray] = None - scores: Optional[jnp.ndarray] = None - - -@flax.struct.dataclass -class GreedyState: - cur_len: jnp.ndarray - sequences: jnp.ndarray - running_token: jnp.ndarray - is_sent_finished: jnp.ndarray - model_kwargs: dict[str, jnp.ndarray] - - -@flax.struct.dataclass -class SampleState: - cur_len: jnp.ndarray - sequences: jnp.ndarray - running_token: jnp.ndarray - is_sent_finished: jnp.ndarray - prng_key: jnp.ndarray - model_kwargs: dict[str, jnp.ndarray] - - -@flax.struct.dataclass -class BeamSearchState: - cur_len: jnp.ndarray - running_sequences: jnp.ndarray - running_scores: jnp.ndarray - sequences: jnp.ndarray - scores: jnp.ndarray - is_sent_finished: jnp.ndarray - model_kwargs: dict[str, jnp.ndarray] - - -class FlaxGenerationMixin: - """ - A class containing all functions for auto-regressive text generation, to be used as a mixin in - [`FlaxPreTrainedModel`]. - - The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and - `do_sample=False` - - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and - `do_sample=True` - - *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and - `do_sample=False` - - You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To - learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). - """ - - def prepare_inputs_for_generation(self, *args, **kwargs): - raise NotImplementedError( - "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." - ) - - @staticmethod - def _run_loop_in_debug(cond_fn, body_fn, init_state): - """ - Run generation in untraced mode. This should only be used for debugging purposes. - """ - state = init_state - while cond_fn(state): - state = body_fn(state) - return state - - def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs): - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) - } - model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs) - return model_kwargs - - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - decoder_start_token_id: Optional[int] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[dict[str, jnp.ndarray]] = None, - ) -> jnp.ndarray: - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - # Only use this arg if not None, otherwise just remove from model_kwargs - decoder_input_ids = model_kwargs.pop("decoder_input_ids") - if decoder_input_ids is not None: - return decoder_input_ids - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0) - - def _get_decoder_start_token_id( - self, decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None - ) -> int: - # retrieve decoder_start_token_id for encoder-decoder models - # fall back to bos_token_id if necessary - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - if decoder_start_token_id is not None: - return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): - return self.config.decoder.decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - return self.config.decoder.bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - - @staticmethod - def _expand_to_num_beams(tensor, num_beams): - return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) - - def _adapt_logits_for_beam_search(self, logits): - """ - This function can be overwritten in the specific modeling_flax_.py classes to allow for custom beam - search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`]. - """ - return logits - - def _validate_model_class(self): - """ - Confirms that the model class is compatible with generation. If not, raises an exception that points to the - right class to use. - """ - if not self.can_generate(): - generate_compatible_mappings = [ - FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, - FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, - FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - ] - generate_compatible_classes = set() - for model_mapping in generate_compatible_mappings: - supported_models = model_mapping.get(type(self.config), default=None) - if supported_models is not None: - generate_compatible_classes.add(supported_models.__name__) - exception_message = ( - f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " - "it doesn't have a language model head." - ) - if generate_compatible_classes: - exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" - raise TypeError(exception_message) - - def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): - """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - unused_model_args = [] - model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) - # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If - # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) - if "kwargs" in model_args or "model_kwargs" in model_args: - model_args |= set(inspect.signature(self.__call__).parameters) - for key, value in model_kwargs.items(): - if value is not None and key not in model_args: - unused_model_args.append(key) - - if unused_model_args: - raise ValueError( - f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" - " generate arguments will also show up in this list)" - ) - - def generate( - self, - input_ids: jnp.ndarray, - generation_config: Optional[GenerationConfig] = None, - prng_key: Optional[jnp.ndarray] = None, - trace: bool = True, - params: Optional[dict[str, jnp.ndarray]] = None, - logits_processor: Optional[FlaxLogitsProcessorList] = None, - **kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head. - - Parameters: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - trace (`bool`, *optional*, defaults to `True`): - Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a - considerably slower runtime. - params (`dict[str, jnp.ndarray]`, *optional*): - Optionally the model parameters can be passed. Can be useful for parallelized generation. - logits_processor (`FlaxLogitsProcessorList `, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - kwargs (`dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`]. - - """ - # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - self._validate_model_class() - - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # two conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same). - if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( - self.generation_config - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" - ) - self.generation_config = new_generation_config - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - self._validate_model_kwargs(model_kwargs.copy()) - - logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() - - # set init values - prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) - - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask") is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - generation_config.pad_token_id = eos_token_id - - if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder: - raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.") - - # decoder-only models should use left-padding for generation (can't be checked with `trace=True`) - if not self.config.is_encoder_decoder and not trace: - if ( - generation_config.pad_token_id is not None - and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - - batch_size = input_ids.shape[0] - - if self.config.is_encoder_decoder: - # add encoder_outputs to model_kwargs - if model_kwargs.get("encoder_outputs") is None: - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) - # prepare decoder_input_ids for generation - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, - model_kwargs=model_kwargs, - ) - - # Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: - # 20 is the default max_length of the generation config - warnings.warn( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length and generation_config.max_length is not None: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - else: # by default let's always generate 20 new tokens - if generation_config.max_length == GenerationConfig().max_length: - generation_config.max_length = generation_config.max_length + input_ids_seq_length - max_position_embeddings = getattr(self.config, "max_position_embeddings", None) - if max_position_embeddings is not None: - generation_config.max_length = min(generation_config.max_length, max_position_embeddings) - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing`max_new_tokens`." - ) - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - logits_processor=logits_processor, - ) - - if not generation_config.do_sample and generation_config.num_beams == 1: - return self._greedy_search( - input_ids, - generation_config.max_length, - generation_config.pad_token_id, - generation_config.eos_token_id, - logits_processor=logits_processor, - trace=trace, - params=params, - model_kwargs=model_kwargs, - ) - elif generation_config.do_sample and generation_config.num_beams == 1: - logits_warper = self._get_logits_warper(generation_config=generation_config) - return self._sample( - input_ids, - generation_config.max_length, - generation_config.pad_token_id, - generation_config.eos_token_id, - prng_key, - logits_warper=logits_warper, - logits_processor=logits_processor, - trace=trace, - params=params, - model_kwargs=model_kwargs, - ) - elif not generation_config.do_sample and generation_config.num_beams > 1: - # broadcast input_ids & encoder_outputs - input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams) - - if "encoder_outputs" in model_kwargs: - model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( - model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams - ) - - for kwarg in ["attention_mask", "decoder_attention_mask"]: - if kwarg in model_kwargs: - model_kwargs[kwarg] = self._expand_to_num_beams( - model_kwargs[kwarg], num_beams=generation_config.num_beams - ) - - return self._beam_search( - input_ids, - generation_config.max_length, - generation_config.pad_token_id, - generation_config.eos_token_id, - length_penalty=generation_config.length_penalty, - early_stopping=generation_config.early_stopping, - logits_processor=logits_processor, - trace=trace, - params=params, - num_return_sequences=generation_config.num_return_sequences, - model_kwargs=model_kwargs, - ) - else: - raise NotImplementedError("`Beam sampling is currently not implemented.") - - def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList: - """ - This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`] - instances used for multinomial sampling. - """ - warpers = FlaxLogitsProcessorList() - - if generation_config.temperature is not None and generation_config.temperature != 1.0: - warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature)) - if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1)) - if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1)) - - return warpers - - def _get_logits_processor( - self, - generation_config: GenerationConfig, - input_ids_seq_length: int, - logits_processor: Optional[FlaxLogitsProcessorList], - ) -> FlaxLogitsProcessorList: - """ - This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] - instances used to modify the scores of the language model head. - """ - processors = FlaxLogitsProcessorList() - - if ( - generation_config.min_length is not None - and generation_config.eos_token_id is not None - and generation_config.min_length > -1 - ): - processors.append( - FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id) - ) - if generation_config.forced_bos_token_id is not None: - processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) - if generation_config.forced_eos_token_id is not None: - processors.append( - FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) - ) - if generation_config.suppress_tokens is not None: - processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens)) - if generation_config.begin_suppress_tokens is not None: - begin_index = input_ids_seq_length - begin_index = ( - begin_index - if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) - else begin_index + 1 - ) - if ( - getattr(generation_config, "forced_decoder_ids", None) is not None - and len(generation_config.forced_decoder_ids) > 0 - ): - # generation starts after the last token that is forced - begin_index += generation_config.forced_decoder_ids[-1][0] - processors.append( - FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) - ) - if getattr(generation_config, "forced_decoder_ids", None) is not None: - forced_decoder_ids = [ - [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids - ] - processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) - if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: - processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) - processors = self._merge_criteria_processor_list(processors, logits_processor) - - return processors - - def _merge_criteria_processor_list( - self, - default_list: FlaxLogitsProcessorList, - custom_list: FlaxLogitsProcessorList, - ) -> FlaxLogitsProcessorList: - if len(custom_list) == 0: - return default_list - for default in default_list: - for custom in custom_list: - if type(custom) is type(default): - object_type = "logits processor" - raise ValueError( - f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" - f" `generate`, but it has already been created with the values {default}. {default} has been" - " created by passing the corresponding arguments to generate or by the model's config default" - f" values. If you just want to change the default values of {object_type} consider passing" - f" them as arguments to `generate` instead of using a custom {object_type}." - ) - default_list.extend(custom_list) - return default_list - - def _greedy_search( - self, - input_ids: None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - logits_processor: Optional[FlaxLogitsProcessorList] = None, - trace: bool = True, - params: Optional[dict[str, jnp.ndarray]] = None, - model_kwargs: Optional[dict[str, jnp.ndarray]] = None, - ): - # init values - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - - batch_size, cur_len = input_ids.shape - - eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) - pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) - cur_len = jnp.array(cur_len) - - # per batch-item holding current token in loop. - sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) - sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) - - # per batch-item state bit indicating if sentence has finished. - is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) - - # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop - # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. - model = self.decode if self.config.is_encoder_decoder else self - # initialize model specific kwargs - model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) - - # initialize state - state = GreedyState( - cur_len=cur_len, - sequences=sequences, - running_token=input_ids, - is_sent_finished=is_sent_finished, - model_kwargs=model_kwargs, - ) - - def greedy_search_cond_fn(state): - """state termination condition fn.""" - has_reached_max_length = state.cur_len == max_length - all_sequence_finished = jnp.all(state.is_sent_finished) - finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) - return ~finish_generation - - def greedy_search_body_fn(state): - """state update fn.""" - model_outputs = model(state.running_token, params=params, **state.model_kwargs) - logits = model_outputs.logits[:, -1] - - # apply min_length, ... - logits = logits_processor(state.sequences, logits, state.cur_len) - - next_token = jnp.argmax(logits, axis=-1) - - next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished - next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) - next_token = next_token[:, None] - - next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) - next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) - return GreedyState( - cur_len=state.cur_len + 1, - sequences=next_sequences, - running_token=next_token, - is_sent_finished=next_is_sent_finished, - model_kwargs=next_model_kwargs, - ) - - # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU - if input_ids.shape[1] > 1: - state = greedy_search_body_fn(state) - - if not trace: - state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state) - else: - state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state) - - return FlaxGreedySearchOutput(sequences=state.sequences) - - def _sample( - self, - input_ids: None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - prng_key: Optional[jnp.ndarray] = None, - logits_processor: Optional[FlaxLogitsProcessorList] = None, - logits_warper: Optional[FlaxLogitsProcessorList] = None, - trace: bool = True, - params: Optional[dict[str, jnp.ndarray]] = None, - model_kwargs: Optional[dict[str, jnp.ndarray]] = None, - ): - # init values - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) - - batch_size, cur_len = input_ids.shape - - eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) - pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) - cur_len = jnp.array(cur_len) - - # per batch-item holding current token in loop. - sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) - sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) - - # per batch-item state bit indicating if sentence has finished. - is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) - - # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop - # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. - model = self.decode if self.config.is_encoder_decoder else self - - # initialize model specific kwargs - model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) - - # initialize state - state = SampleState( - cur_len=cur_len, - sequences=sequences, - running_token=input_ids, - is_sent_finished=is_sent_finished, - prng_key=prng_key, - model_kwargs=model_kwargs, - ) - - def sample_search_cond_fn(state): - """state termination condition fn.""" - has_reached_max_length = state.cur_len == max_length - all_sequence_finished = jnp.all(state.is_sent_finished) - finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) - return ~finish_generation - - def sample_search_body_fn(state): - """state update fn.""" - prng_key, prng_key_next = jax.random.split(state.prng_key) - model_outputs = model(state.running_token, params=params, **state.model_kwargs) - - logits = model_outputs.logits[:, -1] - - # apply min_length, ... - logits = logits_processor(state.sequences, logits, state.cur_len) - # apply top_p, top_k, temperature - logits = logits_warper(logits, logits, state.cur_len) - - next_token = jax.random.categorical(prng_key, logits, axis=-1) - - next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished - next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) - next_token = next_token[:, None] - - next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) - next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) - - return SampleState( - cur_len=state.cur_len + 1, - sequences=next_sequences, - running_token=next_token, - is_sent_finished=next_is_sent_finished, - model_kwargs=next_model_kwargs, - prng_key=prng_key_next, - ) - - # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU - if input_ids.shape[1] > 1: - state = sample_search_body_fn(state) - - if not trace: - state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state) - else: - state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state) - - return FlaxSampleOutput(sequences=state.sequences) - - def _beam_search( - self, - input_ids: None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - early_stopping: Optional[Union[bool, str]] = None, - logits_processor: Optional[FlaxLogitsProcessorList] = None, - trace: bool = True, - params: Optional[dict[str, jnp.ndarray]] = None, - num_return_sequences: Optional[int] = None, - model_kwargs: Optional[dict[str, jnp.ndarray]] = None, - ): - """ - This beam search function is heavily inspired by Flax's official example: - https://github.com/google/flax/blob/main/examples/wmt/decode.py - """ - - def flatten_beam_dim(tensor): - """Flattens the first two dimensions of a non-scalar array.""" - # ignore scalars (e.g. cache index) - if tensor.ndim == 0: - return tensor - return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) - - def unflatten_beam_dim(tensor, batch_size, num_beams): - """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" - # ignore scalars (e.g. cache index) - if tensor.ndim == 0: - return tensor - return tensor.reshape((batch_size, num_beams) + tensor.shape[1:]) - - def gather_beams(nested, beam_indices, batch_size, new_num_beams): - """ - Gathers the beam slices indexed by beam_indices into new beam array. - """ - batch_indices = jnp.reshape( - jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams) - ) - - def gather_fn(tensor): - # ignore scalars (e.g. cache index) - if tensor.ndim == 0: - return tensor - else: - return tensor[batch_indices, beam_indices] - - return jax.tree_util.tree_map(gather_fn, nested) - - # init values - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences - ) - - batch_size, num_beams, cur_len = input_ids.shape - - eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) - pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) - cur_len = jnp.array(cur_len) - - # record the prompt length of decoder - decoder_prompt_len = input_ids.shape[-1] - - # per batch,beam-item holding current token in loop. - sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) - running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) - running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0)) - - # per batch,beam-item state bit indicating if sentence has finished. - is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_) - - # per batch,beam-item score, logprobs - running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]) - scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7) - - # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop - # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. - model = self.decode if self.config.is_encoder_decoder else self - - # flatten beam dim - if "encoder_outputs" in model_kwargs: - model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( - model_kwargs["encoder_outputs"]["last_hidden_state"] - ) - for kwarg in ["attention_mask", "decoder_attention_mask"]: - if kwarg in model_kwargs: - model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg]) - - # initialize model specific kwargs - model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs) - - # initialize state - state = BeamSearchState( - cur_len=cur_len, - running_sequences=running_sequences, - running_scores=running_scores, - sequences=sequences, - scores=scores, - is_sent_finished=is_sent_finished, - model_kwargs=model_kwargs, - ) - - def beam_search_cond_fn(state): - """beam search state termination condition fn.""" - - # 1. is less than max length? - not_max_length_yet = state.cur_len < max_length - - # 2. can the new beams still improve? - # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion - # below for more details. - # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 - # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of - # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there. - if early_stopping == "never" and length_penalty > 0.0: - best_running_score = state.running_scores[:, :1] / ( - (max_length - decoder_prompt_len) ** length_penalty - ) - else: - best_running_score = state.running_scores[:, :1] / ( - (state.cur_len - decoder_prompt_len) ** length_penalty - ) - worst_finished_score = jnp.where( - state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) - ) - improvement_still_possible = jnp.any(best_running_score > worst_finished_score) - - # 3. is there still a beam that has not finished? - still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True)) - - return not_max_length_yet & still_open_beam & improvement_still_possible - - def beam_search_body_fn(state, input_ids_length=1): - """beam search state update fn.""" - # 1. Forward current tokens - # Collect the current position slice along length to feed the fast - # autoregressive decoder model. Flatten the beam dimension into batch - # dimension for feeding into the model. - # unflatten beam dimension - # Unflatten beam dimension in attention cache arrays - input_token = flatten_beam_dim( - lax.dynamic_slice( - state.running_sequences, - (0, 0, state.cur_len - input_ids_length), - (batch_size, num_beams, input_ids_length), - ) - ) - model_outputs = model(input_token, params=params, **state.model_kwargs) - - logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) - cache = jax.tree_util.tree_map( - lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values - ) - - # adapt logits for FlaxMarianMTModel - logits = self._adapt_logits_for_beam_search(logits) - - # 2. Compute log probs - # get log probabilities from logits, - # process logits with processors (*e.g.* min_length, ...), and - # add new logprobs to existing running logprobs scores. - log_probs = jax.nn.log_softmax(logits) - log_probs = logits_processor( - flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len - ) - log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) - log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) - vocab_size = log_probs.shape[2] - log_probs = log_probs.reshape((batch_size, num_beams * vocab_size)) - - # 3. Retrieve top-K - # Each item in batch has num_beams * vocab_size candidate sequences. - # For each item, get the top 2*k candidates with the highest log- - # probabilities. We gather the top 2*K beams here so that even if the best - # K sequences reach EOS simultaneously, we have another K sequences - # remaining to continue the live beam search. - # Gather the top 2*K scores from _all_ beams. - # Gather 2*k top beams. - # Recover the beam index by floor division. - # Recover token id by modulo division and expand Id array for broadcasting. - # Update sequences for the 2*K top-k new sequences. - beams_to_keep = 2 * num_beams - topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep) - topk_beam_indices = topk_indices // vocab_size - topk_running_sequences = gather_beams( - state.running_sequences, topk_beam_indices, batch_size, beams_to_keep - ) - topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) - topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len)) - - # 4. Check which sequences have ended - # Update current sequences: - # Did any of these sequences reach an end marker? - # To prevent these just finished sequences from being added to the current sequences - # set of active beam search sequences, set their log probs to a very large - # negative value. - did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id - running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7) - # 5. Get running sequences scores for next - # Determine the top k beam indices (from top 2*k beams) from log probs - # and gather top k beams (from top 2*k beams). - next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1] - next_running_sequences, next_running_scores = gather_beams( - [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams - ) - - # 6. Process topk logits - # Further process log probs: - # - add length penalty - # - make sure no scores can be added anymore if beam is full - # - make sure still running sequences cannot be chosen as finalized beam - topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty) - beams_in_batch_are_full = jnp.broadcast_to( - state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape - ) & (early_stopping is True) - add_penalty = ~did_topk_just_finished | beams_in_batch_are_full - topk_log_probs += add_penalty * np.array(-1.0e7) - - # 7. Get scores, sequences, is sentence finished for next. - # Combine sequences, scores, and flags along the beam dimension and compare - # new finished sequence scores to existing finished scores and select the - # best from the new set of beams - merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1) - merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1) - merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1) - topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1] - next_sequences, next_scores, next_is_sent_finished = gather_beams( - [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams - ) - - # 8. Update model kwargs. - # Determine the top k beam indices from the original set of all beams. - # With these, gather the top k beam-associated caches. - next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) - next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) - model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache) - next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) - - return BeamSearchState( - cur_len=state.cur_len + 1, - running_scores=next_running_scores, - running_sequences=next_running_sequences, - scores=next_scores, - sequences=next_sequences, - is_sent_finished=next_is_sent_finished, - model_kwargs=next_model_kwargs, - ) - - # Always run first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn` - # when `state.cur_len` equals `decoder_prompt_len`. This also helps to comply with TPU when - # the very first prompt has sequence length > 1. - state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state) - - if not trace: - state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state) - else: - state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state) - - # Account for the edge-case where there are no finished sequences for a - # particular batch item. If so, return running sequences for that batch item. - none_finished = jnp.any(state.is_sent_finished, axis=1) - sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences) - scores = jnp.where(none_finished[:, None], state.scores, state.running_scores) - - # Take best beams for each batch (the score is sorted in descending order) - sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :]) - scores = flatten_beam_dim(scores[:, :num_return_sequences]) - - return FlaxBeamSearchOutput(sequences=sequences, scores=scores) diff --git a/src/transformers/generation/tf_logits_process.py b/src/transformers/generation/tf_logits_process.py deleted file mode 100644 index 436793c402ea..000000000000 --- a/src/transformers/generation/tf_logits_process.py +++ /dev/null @@ -1,600 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect - -import numpy as np -import tensorflow as tf - -from ..tf_utils import stable_softmax -from ..utils import add_start_docstrings -from ..utils.logging import get_logger - - -logger = get_logger(__name__) - - -TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`): - Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam - search or log softmax for each vocabulary token when using beam search. - cur_len (`int`): - The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length - is the maximum length generate can produce, and we need to know which of its tokens are valid. - kwargs (`dict[str, Any]`, *optional*): - Additional logits processor specific kwargs. - - Return: - `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. -""" - - -class TFLogitsProcessor: - """Abstract base class for all logit processors that can be applied during generation.""" - - @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - """TF method for processing logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class TFLogitsWarper: - """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" - - @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - """TF method for warping logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class TFLogitsProcessorList(list): - """ - This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor. - This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the - inputs. - """ - - @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor: - for processor in self: - function_args = inspect.signature(processor.__call__).parameters - if len(function_args) > 3: - if not all(arg in kwargs for arg in list(function_args.keys())[2:]): - raise ValueError( - f"Make sure that all the required parameters: {list(function_args.keys())} for " - f"{processor.__class__} are passed to the logits processor." - ) - scores = processor(input_ids, scores, cur_len, **kwargs) - else: - scores = processor(input_ids, scores, cur_len) - return scores - - -class TFTemperatureLogitsWarper(TFLogitsWarper): - r""" - [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution). - - Args: - temperature (`float`): - The value used to module the logits distribution. - """ - - def __init__(self, temperature: float): - if not isinstance(temperature, float) or not (temperature > 0): - raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") - - self.temperature = temperature - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - scores = scores / self.temperature - return scores - - -class TFTopKLogitsWarper(TFLogitsWarper): - r""" - [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. - - Args: - top_k (`int`): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. - """ - - def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - if not isinstance(top_k, int) or top_k <= 0: - raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") - - self.top_k = max(top_k, min_tokens_to_keep) - self.filter_value = filter_value - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - top_k = min(self.top_k, scores.shape[-1]) # Safety check - # Boolean mask containing all tokens with a probability less than the last token of the top-k - indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:] - next_scores = tf.where(indices_to_remove, self.filter_value, scores) - return next_scores - - -class TFTopPLogitsWarper(TFLogitsWarper): - """ - [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off. - - Args: - top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. - """ - - def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): - raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") - if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): - raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}") - - self.top_p = top_p - self.filter_value = filter_value - self.min_tokens_to_keep = min_tokens_to_keep - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1]) - - mask_scores = tf.fill(scores.shape, self.filter_value) - cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1) - score_mask = cumulative_probs < self.top_p - - # Also include the token that is higher than top_p (the first false = shift and insert a True on the left) - score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1) - - # Ensure min tokens to keep - score_mask = tf.concat( - ( - tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool), - score_mask[:, self.min_tokens_to_keep :], - ), - axis=-1, - ) - - # Mask the values that do not fit the criteria - topk_next_scores = tf.where(score_mask, topk_scores, mask_scores) - - # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size) - # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we - # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`) - scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]]) - scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1) - next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape) - - return next_scores - - -class TFMinLengthLogitsProcessor(TFLogitsProcessor): - r""" - [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0. - - Args: - min_length (`int`): - The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`int`): - The id of the *end-of-sequence* token. - """ - - def __init__(self, min_length: int, eos_token_id: int): - if not isinstance(min_length, int) or min_length < 0: - raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") - - self.min_length = min_length - self.eos_token_id = eos_token_id - - def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor: - eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id - scores = tf.where(eos_token_id_mask, float("-inf"), scores) - return scores - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - # applies eos token masking if the first argument is true - scores = tf.cond( - tf.less(cur_len, self.min_length), - lambda: self._apply_eos_token_mask(scores), - lambda: tf.identity(scores), - ) - return scores - - -class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor): - r""" - [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences. - - Args: - repetition_penalty (`float`): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://huggingface.co/papers/1909.05858) for more details. - """ - - def __init__(self, penalty: float): - if not isinstance(penalty, float) or not (penalty > 0): - raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") - - self.penalty = penalty - - def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: - # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown - # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has - # the same token multiple times. - - # Gathers the penalties to apply - logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1) - logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties) - logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties) - - # Scatters the penalties - token_penalties = tf.ones(logits.shape) - batch_size = input_ids.shape[0] - seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape - indexable_prev_input_ids = tf.concat( - ( - tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1), - tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1), - ), - axis=1, - ) - token_penalties = tf.tensor_scatter_nd_update( - token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1]) - ) - return token_penalties - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores) - - scores = tf.math.multiply(scores, score_penalties) - - return scores - - -class TFNoBadWordsLogitsProcessor(TFLogitsProcessor): - """ - [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled. - - Args: - bad_words_ids (`list[list[int]]`): - List of list of token ids that are not allowed to be generated. In order to get the tokens of the words - that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing - the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space` - argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from - `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). - eos_token_id (`int`): - The id of the *end-of-sequence* token. - """ - - def __init__(self, bad_words_ids: list[list[int]], eos_token_id: int): - if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0: - raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") - if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): - raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") - if any( - any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) - for bad_word_ids in bad_words_ids - ): - raise ValueError( - f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." - ) - - # stores the information about bad words in three tensors: - # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons - self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1) - # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons - bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids] - if any(word_len == 0 for word_len in bad_word_seqs_len): - raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list") - self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32) - # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned - self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids]) - - def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor: - def _tokens_match(bad_word_seq_number): - def _len_one(): - # If the bad sequence only has one token, always mask it - return tf.cond( - tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1), - lambda: tf.ones((), dtype=tf.bool), - _len_greater_than_cur_len, - ) - - def _len_greater_than_cur_len(): - # Otherwise, if the bad sequence is longer than the current length they can't ever match - return tf.cond( - tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]), - lambda: tf.zeros((), dtype=tf.bool), - _match_found, - ) - - def _match_found(): - # Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield - # an answer (otherwise we get indexing exceptions) - compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1 - return tf.cond( - tf.math.reduce_all( - tf.math.equal( - row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len] - ) - ), - lambda: tf.ones((), dtype=tf.bool), - lambda: tf.zeros((), dtype=tf.bool), - ) - - match = _len_one() - return match - - # Compares the current row against all bad word sequences, obtaining a mask with the matches. - match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool) - row_banned_tokens = self.seq_forbidden_tokens[match_mask] - return row_banned_tokens - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous - # `input_ids`, they may have a different length for each row, and they may even be empty for some rows. - # To remain simple and XLA-compatible, we work on a per-row fashion. - # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes - # a frequent choke point. (make `cur_len` a tensor?) - def _get_row_updated_score(row_inputs: tuple[tf.Tensor]) -> tf.Tensor: - row_input_ids, row_score = row_inputs - banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len]) - banned_tokens_mask = tf.scatter_nd( - indices=tf.expand_dims(banned_tokens, axis=-1), - updates=tf.ones_like(banned_tokens, dtype=tf.bool), - shape=row_score.shape, - ) - row_score = tf.where(banned_tokens_mask, -float("inf"), row_score) - return row_score - - scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32) - return scores - - -class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor): - r""" - [`TFLogitsProcessor`] that enforces no repetition of n-grams. See - [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). - - Args: - ngram_size (`int`): - All ngrams of size `ngram_size` can only occur once. - """ - - def __init__(self, ngram_size: int): - if not isinstance(ngram_size, int) or ngram_size <= 0: - raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") - self.ngram_size = ngram_size - - def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len): - # Copied from fairseq for no_repeat_ngram in beam_search - if cur_len + 1 < self.ngram_size: - # return no banned tokens if we haven't generated ngram_size tokens yet - return [[] for _ in range(num_hypos)] - generated_ngrams = [{} for _ in range(num_hypos)] - prev_input_ids = input_ids[:, :cur_len] - for idx in range(num_hypos): - gen_tokens = prev_input_ids[idx].numpy().tolist() - generated_ngram = generated_ngrams[idx] - for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): - prev_ngram_tuple = tuple(ngram[:-1]) - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] - - def _get_generated_ngrams(hypo_idx): - # Before decoding the next token, prevent decoding of ngrams that have already appeared - start_idx = cur_len + 1 - self.ngram_size - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist()) - return generated_ngrams[hypo_idx].get(ngram_idx, []) - - banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] - - return banned_tokens - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - # TODO (joao): enable XLA on this logits processor. See discussion and attempts in - # https://github.com/huggingface/transformers/pull/16974 - if not tf.executing_eagerly(): - raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.") - - batch_size, vocab_size = scores.shape - banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len) - - # create banned_tokens boolean mask - banned_tokens_indices_mask = [] - for banned_tokens_slice in banned_tokens: - banned_tokens_indices_mask.append([token in banned_tokens_slice for token in range(vocab_size)]) - - scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores) - - return scores - - -class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor): - r""" - [`TFLogitsProcessor`] that enforces the specified token as the first generated token. - - Args: - bos_token_id (`int`): - The id of the token to force as the first generated token. - """ - - def __init__(self, bos_token_id: int): - if bos_token_id < 0: - raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}") - self.bos_token_id = bos_token_id - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - if cur_len == 1: - batch_size, num_tokens = scores.shape - # sets the score to 0 in the bos_token_id column - scores = tf.zeros((batch_size, 1)) - # sets the score to -inf everywhere else - if self.bos_token_id > 0: - scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1) - if self.bos_token_id < (num_tokens - 1): - scores = tf.concat( - (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))), - axis=-1, - ) - return scores - - -class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor): - r""" - [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. - - Args: - max_length (`int`): - The maximum length of the sequence to be generated. - eos_token_id (`int`): - The id of the token to force as the last generated token when `max_length` is reached. - """ - - def __init__(self, max_length: int, eos_token_id: int): - self.max_length = max_length - if eos_token_id < 0: - raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}") - self.eos_token_id = eos_token_id - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - if cur_len == self.max_length - 1: - batch_size, num_tokens = scores.shape - # sets the score to 0 in the eos_token_id column - scores = tf.zeros((batch_size, 1)) - # sets the score to -inf everywhere else - if self.eos_token_id > 0: - scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1) - if self.eos_token_id < (num_tokens - 1): - scores = tf.concat( - (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))), - axis=-1, - ) - return scores - - -class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor): - r""" - [`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts - generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not - sampled at the beginning of the generation. - """ - - def __init__(self, begin_suppress_tokens, begin_index): - self.begin_suppress_tokens = list(begin_suppress_tokens) - self.begin_index = begin_index - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - suppressed_indices = [] - for token in self.begin_suppress_tokens: - if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size - suppressed_indices.extend([[i, token] for i in range(scores.shape[0])]) - - if len(suppressed_indices) > 0: - scores = tf.cond( - tf.equal(cur_len, self.begin_index), - lambda: tf.tensor_scatter_nd_update( - scores, - indices=suppressed_indices, - updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))], - ), - lambda: scores, - ) - return scores - - -class TFSuppressTokensLogitsProcessor(TFLogitsProcessor): - r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they - are not sampled.""" - - def __init__(self, suppress_tokens): - self.suppress_tokens = list(suppress_tokens) - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - suppressed_indices = [] - for token in self.suppress_tokens: - if token < scores.shape[-1]: # to ensure we don't go beyond the vocab size - suppressed_indices.extend([[i, token] for i in range(scores.shape[0])]) - - if len(suppressed_indices) > 0: - scores = tf.tensor_scatter_nd_update( - scores, - indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens], - updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))], - ) - return scores - - -class TFForceTokensLogitsProcessor(TFLogitsProcessor): - r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token - indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to - `-inf` so that they are sampled at their corresponding index.""" - - def __init__(self, force_token_map: list[list[int]]): - force_token_map = dict(force_token_map) - # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the - # index of the array corresponds to the index of the token to be forced, for XLA compatibility. - # Indexes without forced tokens will have an negative value. - force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1 - for index, token in force_token_map.items(): - if token is not None: - force_token_array[index] = token - self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32) - - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - def _force_token(generation_idx): - batch_size = scores.shape[0] - current_token = self.force_token_array[generation_idx] - - new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min]) - indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1) - updates = tf.zeros((batch_size,), dtype=scores.dtype) - new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates) - return new_scores - - scores = tf.cond( - tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]), - # If the current length is geq than the length of force_token_array, the processor does nothing. - lambda: tf.identity(scores), - # Otherwise, it may force a certain token. - lambda: tf.cond( - tf.greater_equal(self.force_token_array[cur_len], 0), - # Only valid (positive) tokens are forced - lambda: _force_token(cur_len), - # Otherwise, the processor does nothing. - lambda: scores, - ), - ) - return scores diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py deleted file mode 100644 index be51c9cd9f43..000000000000 --- a/src/transformers/generation/tf_utils.py +++ /dev/null @@ -1,3132 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import inspect -import warnings -from dataclasses import dataclass -from typing import Any, Optional, Union - -import numpy as np -import tensorflow as tf -from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice - -from ..modeling_tf_outputs import TFCausalLMOutputWithPast, TFSeq2SeqLMOutput -from ..models.auto import ( - TF_MODEL_FOR_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, - TF_MODEL_FOR_VISION_2_SEQ_MAPPING, -) -from ..tf_utils import shape_list, stable_softmax -from ..utils import ModelOutput, logging -from .configuration_utils import GenerationConfig -from .tf_logits_process import ( - TFForcedBOSTokenLogitsProcessor, - TFForcedEOSTokenLogitsProcessor, - TFForceTokensLogitsProcessor, - TFLogitsProcessorList, - TFMinLengthLogitsProcessor, - TFNoBadWordsLogitsProcessor, - TFNoRepeatNGramLogitsProcessor, - TFRepetitionPenaltyLogitsProcessor, - TFSuppressTokensAtBeginLogitsProcessor, - TFSuppressTokensLogitsProcessor, - TFTemperatureLogitsWarper, - TFTopKLogitsWarper, - TFTopPLogitsWarper, -) - - -logger = logging.get_logger(__name__) - - -@dataclass -class TFGreedySearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using greedy search. - - - Args: - sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - attentions: Optional[tuple[tuple[tf.Tensor]]] = None - hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFGreedySearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - - Args: - sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - encoder_attentions: Optional[tuple[tf.Tensor]] = None - encoder_hidden_states: Optional[tuple[tf.Tensor]] = None - decoder_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - cross_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - decoder_hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFSampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using sampling. - - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. - attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - attentions: Optional[tuple[tuple[tf.Tensor]]] = None - hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFSampleEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of - the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states - attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size*num_return_sequences, - num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - encoder_attentions: Optional[tuple[tf.Tensor]] = None - encoder_hidden_states: Optional[tuple[tf.Tensor]] = None - decoder_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - cross_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - decoder_hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFBeamSearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam search. - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`tf.Tensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this - beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. - beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `tf.Tensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - sequences_scores: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - beam_indices: Optional[tf.Tensor] = None - attentions: Optional[tuple[tuple[tf.Tensor]]] = None - hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFBeamSearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights - of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states - attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`tf.Tensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this - beam. `Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `tf.Tensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, - sequence_length)`. - cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - sequences_scores: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - beam_indices: Optional[tf.Tensor] = None - encoder_attentions: Optional[tuple[tf.Tensor]] = None - encoder_hidden_states: Optional[tuple[tf.Tensor]] = None - decoder_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - cross_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - decoder_hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFBeamSampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam sample. - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`tf.Tensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this - beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. - beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `tf.Tensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - sequences_scores: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - beam_indices: Optional[tf.Tensor] = None - attentions: Optional[tuple[tuple[tf.Tensor]]] = None - hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFBeamSampleEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`tf.Tensor` of shape `(batch_size*num_beams, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`tf.Tensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this - beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), - with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `tf.Tensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size*num_beams, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - sequences_scores: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - beam_indices: Optional[tf.Tensor] = None - encoder_attentions: Optional[tuple[tf.Tensor]] = None - encoder_hidden_states: Optional[tuple[tf.Tensor]] = None - decoder_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - cross_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - decoder_hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFContrastiveSearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using contrastive search. - - Args: - sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - attentions: Optional[tuple[tuple[tf.Tensor]]] = None - hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -@dataclass -class TFContrastiveSearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using contrastive search. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each - generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: Optional[tf.Tensor] = None - scores: Optional[tuple[tf.Tensor]] = None - encoder_attentions: Optional[tuple[tf.Tensor]] = None - encoder_hidden_states: Optional[tuple[tf.Tensor]] = None - decoder_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - cross_attentions: Optional[tuple[tuple[tf.Tensor]]] = None - decoder_hidden_states: Optional[tuple[tuple[tf.Tensor]]] = None - - -TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput] -TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput] -TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput] -TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput] -TFContrastiveSearchOutput = Union[TFContrastiveSearchEncoderDecoderOutput, TFContrastiveSearchDecoderOnlyOutput] -TFGenerateOutput = Union[ - TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, TFContrastiveSearchOutput -] - - -class TFGenerationMixin: - """ - A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`]. - - The class exposes [`~generation.TFGenerationMixin.generate`], which can be used for: - - *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False` - - *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0` and - `top_k>1` - - *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and - `do_sample=True` - - *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1` - - You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To - learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). - """ - - _seed_generator = None - - @property - def seed_generator(self): - warnings.warn("`seed_generator` is deprecated and will be removed in a future version.", UserWarning) - if self._seed_generator is None: - self._seed_generator = tf.random.Generator.from_non_deterministic_state() - return self._seed_generator - - supports_xla_generation = True - - def prepare_inputs_for_generation(self, *args, **kwargs): - raise NotImplementedError( - "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." - ) - - def compute_transition_scores( - self, - sequences: tf.Tensor, - scores: tuple[tf.Tensor], - beam_indices: Optional[tf.Tensor] = None, - normalize_logits: bool = False, - ) -> tf.Tensor: - """ - Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was - used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time. - - Parameters: - sequences (`tf.Tensor`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or - shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(tf.Tensor)`): - Transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens Tuple of - `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each - tensor of shape `(batch_size*num_beams, config.vocab_size)`. - beam_indices (`tf.Tensor`, *optional*): - Beam indices of generated token id at each generation step. `tf.Tensor` of shape - `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at - generate-time. - normalize_logits (`bool`, *optional*, defaults to `False`): - Whether to normalize the logits (which, for legacy reasons, may be unnormalized). - - Return: - `tf.Tensor`: A `tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing - the transition scores (logits) - - Examples: - - ```python - >>> from transformers import GPT2Tokenizer, TFAutoModelForCausalLM - >>> import numpy as np - - >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - >>> model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer.pad_token_id = tokenizer.eos_token_id - >>> inputs = tokenizer(["Today is"], return_tensors="tf") - - >>> # Example 1: Print the scores for each token generated with Greedy Search - >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) - >>> transition_scores = model.compute_transition_scores( - ... outputs.sequences, outputs.scores, normalize_logits=True - ... ) - >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for - >>> # encoder-decoder models, like BART or T5. - >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] - >>> generated_tokens = outputs.sequences[:, input_length:] - >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): - ... # | token | token string | logits | probability - ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}") - | 262 | the | -1.414 | 24.33% - | 1110 | day | -2.609 | 7.36% - | 618 | when | -2.010 | 13.40% - | 356 | we | -1.859 | 15.58% - | 460 | can | -2.508 | 8.14% - - >>> # Example 2: Reconstruct the sequence scores from Beam Search - >>> outputs = model.generate( - ... **inputs, - ... max_new_tokens=5, - ... num_beams=4, - ... num_return_sequences=4, - ... return_dict_in_generate=True, - ... output_scores=True, - ... ) - >>> transition_scores = model.compute_transition_scores( - ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False - ... ) - >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. - >>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the - >>> # use case, you might want to recompute it with `normalize_logits=True`. - >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1) - >>> length_penalty = model.generation_config.length_penalty - >>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty) - >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) - True - ```""" - # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent - # to a beam search approach were the first (and only) beam is always selected - if beam_indices is None: - beam_indices = tf.tile(tf.expand_dims(tf.range(scores[0].shape[0]), axis=1), [1, len(scores)]) - - # 2. reshape scores as [batch_size, vocab_size, # generation steps] with # generation steps being - # seq_len - input_length - scores = tf.transpose(tf.reshape(tf.stack(scores), (len(scores), -1)), (1, 0)) - scores = tf.reshape(scores, (-1, self.config.vocab_size, scores.shape[-1])) - - # 3. Optionally normalize the logits (across the vocab dimension) - if normalize_logits: - scores = tf.nn.log_softmax(scores, axis=1) - - # 4. cut beam_indices to longest beam length - beam_indices_mask = beam_indices < 0 - max_beam_length = tf.math.reduce_max( - tf.math.reduce_sum((1 - tf.cast(beam_indices_mask, dtype=tf.int32)), axis=-1) - ) - beam_indices = beam_indices[:, -max_beam_length:] - beam_indices_mask = beam_indices_mask[:, -max_beam_length:] - - # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards - beam_indices = tf.where(beam_indices_mask, 0, beam_indices) - - # 6. Define which indices contributed to scores - cut_idx = sequences.shape[-1] - max_beam_length - token_indices = sequences[:, cut_idx:] - gen_step_idx = tf.broadcast_to(tf.range(scores.shape[-1]), token_indices.shape) - indices = tf.stack([beam_indices, token_indices, gen_step_idx], axis=-1) - - # 7. Compute scores - transition_scores = tf.gather_nd(scores, indices) - - # 8. Mask out transition_scores of beams that stopped early - transition_scores = tf.where(beam_indices_mask, 0, transition_scores) - - return transition_scores - - def _validate_model_class(self): - """ - Confirms that the model class is compatible with generation. If not, raises an exception that points to the - right class to use. - """ - if not self.can_generate(): - generate_compatible_mappings = [ - TF_MODEL_FOR_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_VISION_2_SEQ_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, - ] - generate_compatible_classes = set() - for model_mapping in generate_compatible_mappings: - supported_models = model_mapping.get(type(self.config), default=None) - if supported_models is not None: - generate_compatible_classes.add(supported_models.__name__) - exception_message = ( - f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " - "it doesn't have a language model head." - ) - if generate_compatible_classes: - exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" - raise TypeError(exception_message) - - def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): - """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - # Excludes arguments that are handled before calling any model function - if self.config.is_encoder_decoder: - for key in ["decoder_input_ids"]: - model_kwargs.pop(key, None) - - unused_model_args = [] - model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) - # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If - # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) - if "kwargs" in model_args or "model_kwargs" in model_args: - model_args |= set(inspect.signature(self.call).parameters) - for key, value in model_kwargs.items(): - if value is not None and key not in model_args: - unused_model_args.append(key) - - if unused_model_args: - raise ValueError( - f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" - " generate arguments will also show up in this list)" - ) - - def generate( - self, - inputs: Optional[tf.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[TFLogitsProcessorList] = None, - seed=None, - **kwargs, - ) -> Union[TFGenerateOutput, tf.Tensor]: - r""" - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](../generation_strategies). - - - - Parameters: - inputs (`tf.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - seed (`list[int]`, *optional*): - Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the - `seed` argument from stateless functions in `tf.random`. - kwargs (`dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when - `config.return_dict_in_generate=True`) or a `tf.Tensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.TFGreedySearchDecoderOnlyOutput`], - - [`~generation.TFSampleDecoderOnlyOutput`], - - [`~generation.TFBeamSearchDecoderOnlyOutput`], - - [`~generation.TFBeamSampleDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.TFGreedySearchEncoderDecoderOutput`], - - [`~generation.TFSampleEncoderDecoderOutput`], - - [`~generation.TFBeamSearchEncoderDecoderOutput`], - - [`~generation.TFBeamSampleEncoderDecoderOutput`] - - """ - - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - self._validate_model_class() - - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # two conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same). - if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( - self.generation_config - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" - ) - self.generation_config = new_generation_config - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - self._validate_model_kwargs(model_kwargs.copy()) - - # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models) - if inputs is not None: - if isinstance(inputs, tf.Tensor) and inputs.dtype.is_floating: - pass - elif isinstance(inputs, np.ndarray) and np.issubdtype(inputs.dtype, np.floating): - pass - else: - inputs = tf.cast(inputs, tf.int32) - if model_kwargs.get("attention_mask") is not None: - model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32) - if "decoder_input_ids" in model_kwargs: - if ( - isinstance(model_kwargs["decoder_input_ids"], tf.Tensor) - and model_kwargs["decoder_input_ids"].dtype.is_floating - ): - pass - elif isinstance(model_kwargs["decoder_input_ids"], np.ndarray) and np.issubdtype( - model_kwargs["decoder_input_ids"].dtype, np.floating - ): - pass - else: - model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32) - - # 3. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask") is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - generation_config.pad_token_id = eos_token_id - - use_xla = not tf.executing_eagerly() - if use_xla and not self.supports_xla_generation: - raise ValueError( - "The selected model does not support Graph mode nor XLA generation (e.g. from tf.function())" - ) - - # 4. Define model inputs - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - # inputs_ids now has to be defined and cannot be None anymore - batch_size = shape_list(inputs_tensor)[0] - - # 5. Prepare other model kwargs - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs["output_hidden_states"] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache - - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id - ) - - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: - if generation_config.pad_token_id is not None and tf.math.reduce_any( - inputs_tensor[:, -1] == generation_config.pad_token_id - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name - ) - - # 6. Prepare model inputs which will be used for auto-regressive generation - if self.config.is_encoder_decoder: - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=batch_size, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, - ) - else: - input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") - - # 7. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = shape_list(input_ids)[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: - # 20 is the default max_length of the generation config - warnings.warn( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length and generation_config.max_length is not None: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - # If the input length is a tensor (i.e. dynamic length), skip length checks - if not isinstance(input_ids_seq_length, tf.Tensor): - if ( - generation_config.min_length is not None - and generation_config.min_length > generation_config.max_length - ): - raise ValueError( - f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger" - f" than the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing`max_new_tokens`." - ) - - # 8. determine generation mode - is_contrastive_search_gen_mode = ( - generation_config.top_k is not None - and generation_config.top_k > 1 - and generation_config.do_sample is False - and generation_config.penalty_alpha is not None - and generation_config.penalty_alpha > 0 - ) - is_greedy_gen_mode = ( - not is_contrastive_search_gen_mode - and (generation_config.num_beams == 1) - and generation_config.do_sample is False - ) - is_beam_gen_mode = ( - not is_contrastive_search_gen_mode - and (generation_config.num_beams > 1) - and generation_config.do_sample is False - ) - is_sample_gen_mode = (generation_config.num_beams == 1) and generation_config.do_sample is True - is_beam_sample_gen_mode = (generation_config.num_beams > 1) and generation_config.do_sample is True - - # 9. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - logits_processor=logits_processor, - ) - - # 10. go into different generation modes - if is_greedy_gen_mode: - if generation_config.num_return_sequences > 1: - raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " greedy search." - ) - # 11. run greedy search - return self.greedy_search( - input_ids, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - logits_processor=logits_processor, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - **model_kwargs, - ) - elif is_contrastive_search_gen_mode: - if generation_config.num_return_sequences > 1: - raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " contrastive search." - ) - # 11. run contrastive search - return self.contrastive_search( - input_ids, - top_k=generation_config.top_k, - penalty_alpha=generation_config.penalty_alpha, - logits_processor=logits_processor, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - **model_kwargs, - ) - elif is_sample_gen_mode: - # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config=generation_config) - - # 12. expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 13. run sample - return self.sample( - input_ids, - logits_processor=logits_processor, - logits_warper=logits_warper, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - seed=seed, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - **model_kwargs, - ) - - elif is_beam_gen_mode: - if generation_config.num_beams < generation_config.num_return_sequences: - raise ValueError( - "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >=" - f" num_return_sequences, got {generation_config.num_beams} and" - f" {generation_config.num_return_sequences} (respectively)" - ) - - # 11. broadcast inputs to the desired number of beams - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - expand_in_new_axis=True, - **model_kwargs, - ) - - # 12. run beam search - return self.beam_search( - input_ids, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - length_penalty=generation_config.length_penalty, - early_stopping=generation_config.early_stopping, - logits_processor=logits_processor, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - num_return_sequences=generation_config.num_return_sequences, - **model_kwargs, - ) - - elif is_beam_sample_gen_mode: - if generation_config.num_beams < generation_config.num_return_sequences: - raise ValueError( - "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >=" - f" num_return_sequences, got {generation_config.num_beams} and" - f" {generation_config.num_return_sequences} (respectively)" - ) - - # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config=generation_config) - - # 12. broadcast inputs to the desired number of beams - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - expand_in_new_axis=True, - **model_kwargs, - ) - - # 13. run beam sample (beam search with sampling) - return self.beam_search( - input_ids, - do_sample=True, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - length_penalty=generation_config.length_penalty, - early_stopping=generation_config.early_stopping, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - num_return_sequences=generation_config.num_return_sequences, - **model_kwargs, - ) - - def _prepare_attention_mask_for_generation( - self, - inputs: tf.Tensor, - pad_token_id: Optional[int], - eos_token_id: Optional[int], - ) -> tf.Tensor: - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64) - is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) - - # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: - return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32) - else: - return tf.ones(inputs.shape[:2], dtype=tf.int32) - - def _prepare_encoder_decoder_kwargs_for_generation( - self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None - ) -> dict[str, Any]: - # 1. get encoder and store encoder outputs - encoder = self.get_encoder() - - # 2. prepare encoder args and encoder kwargs from model kwargs - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - encoder_signature = set(inspect.signature(encoder.call).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - if not encoder_accepts_wildcard: - encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature - } - - # 3. vision models don't use `attention_mask`. - encoder_kwargs["return_dict"] = True - encoder_kwargs[model_input_name] = inputs_tensor - if model_input_name != self.main_input_name: # in Keras, the first input must always be passed - encoder_kwargs[self.main_input_name] = None - encoder_outputs = encoder(**encoder_kwargs) - model_kwargs["encoder_outputs"] = encoder_outputs - - return model_kwargs - - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - model_input_name: str, - model_kwargs: dict[str, tf.Tensor], - decoder_start_token_id: Optional[int] = None, - bos_token_id: Optional[int] = None, - ) -> tuple[tf.Tensor, dict[str, tf.Tensor]]: - """Prepares `decoder_input_ids` for generation with encoder-decoder models""" - # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, - # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - decoder_input_ids = model_kwargs.pop("decoder_input_ids") - elif "input_ids" in model_kwargs and model_input_name != "input_ids": - decoder_input_ids = model_kwargs.pop("input_ids") - else: - decoder_input_ids = None - - # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id - - # no user input -> use decoder_start_token_id as decoder_input_ids - if decoder_input_ids is None: - decoder_input_ids = decoder_input_ids_start - # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust - # decoder_attention_mask if provided) - elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id): - decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1) - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - decoder_attention_mask = tf.concat( - (tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - axis=-1, - ) - model_kwargs["decoder_attention_mask"] = decoder_attention_mask - - return decoder_input_ids, model_kwargs - - def _get_decoder_start_token_id( - self, decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None - ) -> int: - # retrieve decoder_start_token_id for encoder-decoder models - # fall back to bos_token_id if necessary - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - - @staticmethod - def _expand_inputs_for_generation( - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[tf.Tensor] = None, - expand_in_new_axis: bool = False, - **model_kwargs, - ) -> tuple[tf.Tensor, dict[str, Any]]: - """ - Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...], - depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with - `expand_in_new_axis=True` - """ - - def _expand_tensor(tensor: tf.Tensor): - if expand_in_new_axis: - shape = shape_list(tensor) - return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:])) - else: - return tf.repeat(tensor, expand_size, axis=0) - - def _expand_dict_for_generation(dict_to_expand): - for key in dict_to_expand: - if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor): - dict_to_expand[key] = _expand_tensor(dict_to_expand[key]) - return dict_to_expand - - if input_ids is not None: - input_ids = _expand_tensor(input_ids) - - model_kwargs = _expand_dict_for_generation(model_kwargs) - - if is_encoder_decoder: - if model_kwargs.get("encoder_outputs") is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) - - return input_ids, model_kwargs - - def _prepare_model_inputs( - self, - inputs: Optional[tf.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[dict[str, tf.Tensor]] = None, - ) -> tuple[tf.Tensor, Optional[str], dict[str, tf.Tensor]]: - """ - This function extracts the model-specific `inputs` for generation. - """ - # 1. retrieve all kwargs that are non-None or non-model input related. - # some encoder-decoder models have different names for model and encoder - if ( - self.config.is_encoder_decoder - and hasattr(self, "encoder") - and hasattr(self.encoder, "main_input_name") - and self.encoder.main_input_name != self.main_input_name - ): - input_name = self.encoder.main_input_name - else: - input_name = self.main_input_name - - model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} - - # 2. check whether model_input_name is passed as kwarg - # if yes and `inputs` is None use kwarg inputs - inputs_kwarg = model_kwargs.pop(input_name, None) - if inputs_kwarg is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " - f"Make sure to either pass {inputs} or {input_name}=..." - ) - elif inputs_kwarg is not None: - inputs = inputs_kwarg - - # 3. In the presence of `inputs_embeds` for text models: - # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model - # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with - # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) - # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and - # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. - if input_name == "input_ids" and "inputs_embeds" in model_kwargs: - if not self.config.is_encoder_decoder: - has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature(self.prepare_inputs_for_generation).parameters.keys() - ) - if not has_inputs_embeds_forwarding: - raise ValueError( - f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " - "doesn't have its forwarding implemented. See the GPT2 implementation for an example " - "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" - ) - # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of - # the attention mask) can rely on the actual model input. - model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs=model_kwargs - ) - else: - if inputs is not None: - raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") - inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" - - # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) - - return inputs, input_name, model_kwargs - - def _maybe_initialize_input_ids_for_generation( - self, - inputs: Optional[tf.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[dict[str, tf.Tensor]] = None, - ) -> tf.Tensor: - """Initializes input ids for generation, if necessary.""" - if inputs is not None: - return inputs - - encoder_outputs = model_kwargs.get("encoder_outputs") - if self.config.is_encoder_decoder and encoder_outputs is not None: - # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding - shape = encoder_outputs.last_hidden_state.shape[:-1] - return tf.ones(shape, dtype=tf.int32) * -100 - - if bos_token_id is None: - raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - - # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with - # soft-prompting or in multimodal implementations built on top of decoder-only language models. - batch_size = 1 - for value in model_kwargs.values(): - if isinstance(value, tf.Tensor): - batch_size = value.shape[0] - break - return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id - - @staticmethod - def _extract_past_from_model_output(outputs: ModelOutput): - past_key_values = None - if "past_key_values" in outputs: - past_key_values = outputs.past_key_values - elif "mems" in outputs: - past_key_values = outputs.mems - elif "past_buckets_states" in outputs: - past_key_values = outputs.past_buckets_states - return past_key_values - - def _update_model_kwargs_for_generation( - self, outputs: ModelOutput, model_kwargs: dict[str, Any], is_encoder_decoder: bool = False - ) -> dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs) - - # update attention mask - if not is_encoder_decoder: - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = tf.concat( - [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 - ) - - return model_kwargs - - def _update_model_kwargs_for_xla_generation( - self, - model_outputs: ModelOutput, - model_kwargs: dict[str, Any], - cur_len: int, - max_length: int, - batch_size: int, - is_encoder_decoder: bool = False, - batch_axis: int = 0, - ): - def _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder): - """initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`""" - if is_encoder_decoder: - # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor, - # 1s for the actual input_ids - decoder_attention_mask = tf.concat( - [ - tf.ones((batch_size, 1), dtype=tf.int32), - tf.zeros((batch_size, num_padding_values), dtype=tf.int32), - tf.ones((batch_size, 1), dtype=tf.int32), - ], - axis=1, - ) - mask = {"decoder_attention_mask": decoder_attention_mask} - else: - attention_mask = model_kwargs.pop("attention_mask") - # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids - attention_mask = tf.concat( - [ - attention_mask, - tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype), - tf.ones((batch_size, 1), dtype=attention_mask.dtype), - ], - axis=1, - ) - mask = {"attention_mask": attention_mask} - return mask - - def _update_attention(model_kwargs, new_past_index, is_encoder_decoder): - """updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`""" - update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index - if is_encoder_decoder: - decoder_attention_mask = model_kwargs.pop("decoder_attention_mask") - decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype) - decoder_attention_mask = dynamic_update_slice( - decoder_attention_mask, decoder_attention_mask_update_slice, update_start - ) - mask = {"decoder_attention_mask": decoder_attention_mask} - else: - attention_mask = model_kwargs.pop("attention_mask") - attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype) - attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start) - mask = {"attention_mask": attention_mask} - return mask - - def _initialize_past(past_key_values, num_padding_values, batch_axis): - """initialize past_key_values with zeros -- the structure depends on `batch_axis`""" - if batch_axis == 0: - padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32) - new_past = () - for past_layer in past_key_values: - new_past_layer = list(past_layer) - for i in range(len(new_past_layer[:2])): - new_past_layer[i] = tf.pad(past_layer[i], padding_values) - new_past += (tuple(new_past_layer),) - else: - padding_values = tf.scatter_nd(indices=[[3, 1]], updates=[num_padding_values], shape=(5, 2)) - new_past = list(past_key_values) - for i in range(len(past_key_values)): - new_past[i] = tf.pad(past_key_values[i], padding_values) - return new_past - - def _update_past(past_key_values, new_past_index, batch_axis): - if batch_axis == 0: - slice_start_base = tf.constant([0, 0, 1, 0]) - new_past = () - for past_layer in past_key_values: - new_past_layer = list(past_layer) - for i in range(len(new_past_layer[:2])): - update_slice = past_layer[i][:, :, -1:] - # Write the last slice to the first open location in the padded past_key_values array - # and then truncate the last slice off the array - new_past_layer[i] = dynamic_update_slice( - past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index - ) - new_past += (tuple(new_past_layer),) - else: - slice_start_base = tf.constant([0, 0, 0, 1, 0]) - new_past = [None for _ in range(len(past_key_values))] - for i in range(len(past_key_values)): - update_slice = past_key_values[i][:, :, :, -1:] - # Write the last slice to the first open location in the padded past_key_values array - # and then truncate the last slice off the array - new_past[i] = dynamic_update_slice( - past_key_values[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index - ) - return new_past - - past_key_values = self._extract_past_from_model_output(model_outputs) - if past_key_values is None: - raise ValueError( - "No known `past_key_values variable` found in model outputs (model outputs keys:" - f" {list(model_outputs.keys())})" - ) - is_past_initialized = model_kwargs.pop("past_key_values", None) is not None - - if not is_past_initialized: - # The padded version of `past_key_values` has a length of `max_length - 1`, as `past_key_values` holds information relative to - # previous autoregressive generation steps (step 0 has no past_key_values, step 1 has 1 past_key_values value, ..., the last step - # has `max_length - 1` past_key_values values). - num_padding_values = max_length - cur_len - 1 - mask = _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder) - new_past = _initialize_past(past_key_values, num_padding_values, batch_axis) - else: - # The new index of past_key_values to be filled corresponds to the current length of the sequence, with two - # subtractions: -1 because past_key_values holds information regarding previous generation steps (read comment above) - # and -1 again because in an array the index is the length of the array minus 1. - new_past_index = cur_len - 2 - mask = _update_attention(model_kwargs, new_past_index, is_encoder_decoder) - new_past = _update_past(past_key_values, new_past_index, batch_axis) - - # sets the updated variables (mask and past_key_values) - model_kwargs.update(mask) - model_kwargs["past_key_values"] = tuple(new_past) - - return model_kwargs - - def _get_logits_warper( - self, - generation_config: GenerationConfig, - ) -> TFLogitsProcessorList: - """ - This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`] - instances used for multinomial sampling. - """ - - # instantiate warpers list - warpers = TFLogitsProcessorList() - - # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a - # better score (i.e. keep len(generation_config.eos_token_id) + 1) - if generation_config.num_beams > 1: - if isinstance(generation_config.eos_token_id, list): - min_tokens_to_keep = len(generation_config.eos_token_id) + 1 - else: - min_tokens_to_keep = 2 - else: - min_tokens_to_keep = 1 - - if generation_config.temperature is not None and generation_config.temperature != 1.0: - warpers.append(TFTemperatureLogitsWarper(generation_config.temperature)) - if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) - return warpers - - def _get_logits_processor( - self, - generation_config: GenerationConfig, - input_ids_seq_length: int, - logits_processor: Optional[TFLogitsProcessorList], - ) -> TFLogitsProcessorList: - """ - This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`] - instances used to modify the scores of the language model head. - """ - processors = TFLogitsProcessorList() - - # instantiate processors list - if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: - processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) - if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: - processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) - if generation_config.bad_words_ids is not None: - processors.append( - TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) - ) - if ( - generation_config.min_length is not None - and generation_config.eos_token_id is not None - and generation_config.min_length > 0 - ): - processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) - if generation_config.forced_bos_token_id is not None: - processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) - if generation_config.forced_eos_token_id is not None: - processors.append( - TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) - ) - if generation_config.suppress_tokens is not None: - processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens)) - if generation_config.begin_suppress_tokens is not None: - begin_index = input_ids_seq_length - begin_index = ( - begin_index - if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) - else begin_index + 1 - ) - if getattr(generation_config, "forced_decoder_ids", None) is not None: - begin_index += generation_config.forced_decoder_ids[-1][ - 0 - ] # generation starts after the last token that is forced - processors.append( - TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) - ) - if getattr(generation_config, "forced_decoder_ids", None) is not None: - processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) - - processors = self._merge_criteria_processor_list(processors, logits_processor) - return processors - - def _merge_criteria_processor_list( - self, - default_list: TFLogitsProcessorList, - custom_list: TFLogitsProcessorList, - ) -> TFLogitsProcessorList: - if len(custom_list) == 0: - return default_list - for default in default_list: - for custom in custom_list: - if type(custom) is type(default): - object_type = "logits processor" - raise ValueError( - f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" - f" `generate`, but it has already been created with the values {default}. {default} has been" - " created by passing the corresponding arguments to generate or by the model's config default" - f" values. If you just want to change the default values of {object_type} consider passing" - f" them as arguments to `generate` instead of using a custom {object_type}." - ) - default_list.extend(custom_list) - return default_list - - def greedy_search( - self, - input_ids: tf.Tensor, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - logits_processor: Optional[TFLogitsProcessorList] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - **model_kwargs, - ) -> Union[TFGreedySearchOutput, tf.Tensor]: - r""" - Generates sequences for models with a language modeling head using greedy decoding. - - Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, list[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `call` function of the model. If - model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.TFGreedySearchDecoderOnlyOutput`], [`~generation.TFGreedySearchEncoderDecoderOutput`] or - `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a - [`~generation.TFGreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.TFGreedySearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... TFAutoModelForCausalLM, - ... TFLogitsProcessorList, - ... TFMinLengthLogitsProcessor, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - >>> model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id - - >>> input_prompt = "Today is a beautiful day, and" - >>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids - - >>> # instantiate logits processors - >>> logits_processor = TFLogitsProcessorList( - ... [ - ... TFMinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["Today is a beautiful day, and I'm so happy to be here. I'm so happy to"] - ```""" - - # 1. init greedy_search values - logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) - use_xla = not tf.executing_eagerly() - # TODO (Joao): fix cache format or find programmatic way to detect cache index - # GPT2 and other models has a slightly different cache structure, with a different batch axis - model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) - cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0 - # some models, like XLNet, need more than the last token in the presence of past_key_values - needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) - - # 2. init `attentions`, `hidden_states`, and `scores` tuples - scores = [] if (return_dict_in_generate and output_scores) else None - decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None - cross_attentions = [] if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - - # 3. init tensors to use for "xla-compileable" generate function - batch_size, cur_len = shape_list(input_ids) - - # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` - input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) - generated = tf.concat([input_ids, input_ids_padding], axis=-1) - finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) - - # 4. define "xla-compile-able" stop-condition and auto-regressive function - # define condition fn - def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs): - """state termination condition fn.""" - return ~tf.reduce_all(finished_sequences) - - # define condition fn - def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs): - """state update fn.""" - if model_kwargs.get("past_key_values") is None or needs_full_input: - input_ids = generated[:, :cur_len] - else: - input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs) - # forward pass to get next token logits - model_outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - next_token_logits = model_outputs.logits[:, -1] - - # pre-process distribution - next_tokens_scores = logits_processor(generated, next_token_logits, cur_len) - - # Store scores, attentions and hidden_states when required - if not use_xla and return_dict_in_generate: - if output_scores: - scores.append(next_tokens_scores) - if output_attentions and self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.decoder_attentions) - elif output_attentions and not self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.attentions) - if self.config.is_encoder_decoder: - cross_attentions.append(model_outputs.cross_attentions) - - if output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.decoder_hidden_states) - elif output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.hidden_states) - - # argmax - next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32) - - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) - next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) - next_token_is_eos = tf.math.reduce_any( - tf.equal( - tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1) - ), - axis=0, - ) - finished_sequences = finished_sequences | next_token_is_eos - - # update `generated` and `cur_len` - update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) - generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) - cur_len += 1 - - # update model_kwargs - if use_xla: - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_outputs=model_outputs, - model_kwargs=model_kwargs, - cur_len=cur_len, - max_length=max_length, - batch_size=batch_size, - is_encoder_decoder=self.config.is_encoder_decoder, - batch_axis=cache_batch_axis, - ) - else: - model_kwargs = self._update_model_kwargs_for_generation( - model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - # if we don't cache past_key_values key values we need the whole input - if model_kwargs.get("past_key_values", None) is None: - # let's throw out `past_key_values` since we don't want `None` tensors - model_kwargs.pop("past_key_values", None) - - return generated, finished_sequences, cur_len, model_kwargs - - # 5. run generation - # 1st generation step has to be run before to initialize `past_key_values` - generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn( - generated, finished_sequences, cur_len, model_kwargs - ) - - # 2-to-n generation steps can then be run in autoregressive fashion - # only in case 1st generation step does NOT yield EOS token though - maximum_iterations = max_length - cur_len - generated, _, cur_len, _ = tf.while_loop( - greedy_search_cond_fn, - greedy_search_body_fn, - (generated, finished_sequences, cur_len, model_kwargs), - maximum_iterations=maximum_iterations, - ) - - # 6. prepare outputs - if not use_xla: - # cut for backward compatibility - generated = generated[:, :cur_len] - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - # if model is an encoder-decoder, retrieve encoder attention weights - # and hidden states - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - scores = tuple(scores) if scores is not None else None - decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None - cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None - decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None - - return TFGreedySearchEncoderDecoderOutput( - sequences=generated, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return TFGreedySearchDecoderOnlyOutput( - sequences=generated, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return generated - - def sample( - self, - input_ids: tf.Tensor, - logits_processor: Optional[TFLogitsProcessorList] = None, - logits_warper: Optional[TFLogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - seed: Optional[tuple[int, int]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - **model_kwargs, - ) -> Union[TFSampleOutput, tf.Tensor]: - r""" - Generates sequences for models with a language modeling head using multinomial sampling. - - Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`] - used to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, list[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - seed (`list[int]`, *optional*): - Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the - `seed` argument from stateless functions in `tf.random`. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - model_kwargs: - Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an - encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.TFSampleDecoderOnlyOutput`], [`~generation.TFSampleEncoderDecoderOutput`] or `tf.Tensor`: A - `tf.Tensor` containing the generated tokens (default behaviour) or a - [`~generation.TFSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.TFSampleEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import ( - ... AutoTokenizer, - ... TFAutoModelForCausalLM, - ... TFLogitsProcessorList, - ... TFMinLengthLogitsProcessor, - ... TFTopKLogitsWarper, - ... TFTemperatureLogitsWarper, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - >>> model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id - - >>> input_prompt = "Today is a beautiful day, and" - >>> input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids - - >>> # instantiate logits processors - >>> logits_processor = TFLogitsProcessorList( - ... [ - ... TFMinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> # instantiate logits processors - >>> logits_warper = TFLogitsProcessorList( - ... [ - ... TFTopKLogitsWarper(50), - ... TFTemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> tf.random.set_seed(0) - >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today is a beautiful day, and I love my country. But when I look at Donald Trump,'] - ```""" - - # 1. init greedy_search values - logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) - use_xla = not tf.executing_eagerly() - # TODO (Joao): fix cache format or find programmatic way to detect cache index - # GPT2 and other models has a slightly different cache structure, with a different batch axis - model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) - cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0 - # some models, like XLNet, need more than the last token in the presence of past_key_values - needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) - - # 2. init `attentions`, `hidden_states`, and `scores` tuples - scores = [] if (return_dict_in_generate and output_scores) else None - decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None - cross_attentions = [] if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - - # 3. init tensors to use for "xla-compileable" generate function - batch_size, cur_len = shape_list(input_ids) - - # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` - input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) - generated = tf.concat([input_ids, input_ids_padding], axis=-1) - finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) - - # 4. define "xla-compile-able" stop-condition and auto-regressive function - def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs): - return ~tf.reduce_all(finished_sequences) - - def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs): - if model_kwargs.get("past_key_values") is None or needs_full_input: - input_ids = generated[:, :cur_len] - else: - input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs) - # forward pass to get next token logits - model_outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - next_token_logits = model_outputs.logits[:, -1] - - # pre-process distribution - next_tokens_scores = logits_processor(generated, next_token_logits, cur_len) - next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len) - - # Store scores, attentions and hidden_states when required - if not use_xla and return_dict_in_generate: - if output_scores: - scores.append(next_tokens_scores) - if output_attentions and self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.decoder_attentions) - elif output_attentions and not self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.attentions) - if self.config.is_encoder_decoder: - cross_attentions.append(model_outputs.cross_attentions) - - if output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.decoder_hidden_states) - elif output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.hidden_states) - - # sample - if seed is not None: - sample_seed = seed - else: - sample_seed = tf.experimental.numpy.random.randint(tf.int32.min, tf.int32.max, (2,), dtype=tf.int32) - next_tokens = tf.squeeze( - tf.random.stateless_categorical( - logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32 - ), - axis=1, - ) - - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) - next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) - next_token_is_eos = tf.math.reduce_any( - tf.equal( - tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1) - ), - axis=0, - ) - finished_sequences = finished_sequences | next_token_is_eos - - # update `generated` and `cur_len` - update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) - generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) - cur_len += 1 - - # update model_kwargs - if use_xla: - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_outputs=model_outputs, - model_kwargs=model_kwargs, - cur_len=cur_len, - max_length=max_length, - batch_size=batch_size, - is_encoder_decoder=self.config.is_encoder_decoder, - batch_axis=cache_batch_axis, - ) - else: - model_kwargs = self._update_model_kwargs_for_generation( - model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - # if we don't cache past_key_values key values we need the whole input - if model_kwargs.get("past_key_values", None) is None: - # let's throw out `past_key_values` since we don't want `None` tensors - model_kwargs.pop("past_key_values", None) - - return generated, finished_sequences, cur_len, model_kwargs - - # 5. run generation - # 1st generation step has to be run before to initialize `past_key_values` - generated, finished_sequences, cur_len, model_kwargs = sample_body_fn( - generated, finished_sequences, cur_len, model_kwargs - ) - - # 2-to-n generation steps can then be run in autoregressive fashion - # only in case 1st generation step does NOT yield EOS token though - maximum_iterations = max_length - cur_len - generated, _, cur_len, _ = tf.while_loop( - sample_cond_fn, - sample_body_fn, - (generated, finished_sequences, cur_len, model_kwargs), - maximum_iterations=maximum_iterations, - ) - - # 6. prepare outputs - if not use_xla: - # cut for backward compatibility - generated = generated[:, :cur_len] - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - # if model is an encoder-decoder, retrieve encoder attention weights - # and hidden states - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - scores = tuple(scores) if scores is not None else None - decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None - cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None - decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None - - return TFSampleEncoderDecoderOutput( - sequences=generated, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return TFSampleDecoderOnlyOutput( - sequences=generated, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return generated - - @staticmethod - def _gather_beams(nested, beam_indices, batch_axis=0): - """Gathers the beam slices indexed by beam_indices into new beam array.""" - - def gather_fn(tensor): - if batch_axis > 0: - # pushes all dimensions before the batch to the end, so we get (batch, beam_id, ...) - perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0) - tensor = tf.transpose(tensor, perm=perm) - - gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1) - if batch_axis > 0: - # transposes back to the original dimensions - perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0) - perm = tf.math.invert_permutation(perm) - gathered_tensor = tf.transpose(gathered_tensor, perm=perm) - - return gathered_tensor - - return tf.nest.map_structure(gather_fn, nested) - - def beam_search( - self, - input_ids: tf.Tensor, - do_sample: bool = False, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - early_stopping: Optional[Union[bool, str]] = None, - logits_processor: Optional[TFLogitsProcessorList] = None, - logits_warper: Optional[TFLogitsProcessorList] = None, - num_return_sequences: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - **model_kwargs, - ) -> Union[TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]: - r""" - Generates sequences for models with a language modeling head using beam search. If `do_sample` is `False`, uses - a greedy approach, otherwise does multinomial sampling without replacement. - - Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, num_beams, sequence_length)`): - The sequence used as a prompt for the generation. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, list[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - early_stopping (`bool` or `str`, *optional*, defaults to `False`): - Controls the stopping condition for beam-based methods, like beam-search. It accepts the following - values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; - `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better - candidates; `"never"`, where the beam search procedure only stops when there cannot be better - candidates (canonical beam search algorithm). - logits_processor (`[TFLogitsProcessorList]`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`] - used to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. - model_kwargs: - Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an - encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.TFBeamSearchDecoderOnlyOutput`], [`~generation.TFBeamSearchEncoderDecoderOutput`] or - `tf.Tensor`: A `tf.Tensor` containing the generated tokens (default behaviour) or a - [`~generation.TFBeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.TFBeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... TFAutoModelForSeq2SeqLM, - ... TFLogitsProcessorList, - ... TFMinLengthLogitsProcessor, - ... ) - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = tf.ones((1, num_beams, 1), dtype=tf.int32) - >>> input_ids = input_ids * model.generation_config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> encoder_outputs = model.get_encoder()(encoder_input_ids, return_dict=True) - >>> encoder_outputs.last_hidden_state = tf.repeat( - ... tf.expand_dims(encoder_outputs.last_hidden_state, axis=0), num_beams, axis=1 - ... ) - >>> model_kwargs = {"encoder_outputs": encoder_outputs} - - >>> # instantiate logits processors - >>> logits_processor = TFLogitsProcessorList( - ... [TFMinLengthLogitsProcessor(5, eos_token_id=model.generation_config.eos_token_id)] - ... ) - - >>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - - def flatten_beam_dim(tensor, batch_axis=0): - """Flattens the first two dimensions of a non-scalar array.""" - shape = shape_list(tensor) - return tf.reshape( - tensor, - shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :], - ) - - def unflatten_beam_dim(tensor, num_beams, batch_axis=0): - """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" - shape = shape_list(tensor) - return tf.reshape(tensor, shape[:batch_axis] + [-1, num_beams] + shape[batch_axis + 1 :]) - - # 1. init beam_search values - logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences - ) - - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping - - use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache) - use_xla = not tf.executing_eagerly() - # TODO (Joao): fix cache format or find programmatic way to detect cache index - # GPT2 and other models has a slightly different cache structure, with a different batch axis - model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) - cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0 - # some models, like XLNet, need more than the last token in the presence of past_key_values - needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) - - # 2. init `attentions`, `hidden_states`, and `scores` tuples - all_scores = [] if (return_dict_in_generate and output_scores) else None - decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None - cross_attentions = [] if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - - # 3. init tensors to use for "xla-compileable" generate function - batch_size, num_beams, cur_len = shape_list(input_ids) - # store the prompt length of decoder - decoder_prompt_len = cur_len - - # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` - input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * ( - pad_token_id or 0 - ) - running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1) - sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0) - - # per batch,beam-item state bit indicating if sentence has finished. - is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool) - - # per batch, beam-item score, logprobs - running_scores = tf.tile( - tf.expand_dims(tf.convert_to_tensor([0.0] + [-1.0e9] * (num_beams - 1)), axis=0), [batch_size, 1] - ) - scores = tf.ones((batch_size, num_beams)) * -1.0e9 - - # per batch beam indices - running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1 - beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1 - - # flatten beam dim - if "encoder_outputs" in model_kwargs: - model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( - model_kwargs["encoder_outputs"]["last_hidden_state"] - ) - if "attention_mask" in model_kwargs: - model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) - - # 4. define "xla-compile-able" stop-condition and auto-regressive function - # define stop-condition and auto-regressive function - def beam_search_cond_fn( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - model_kwargs, - ): - """ - Beam Search termination condition function -- halts the generation loop if any of these conditions becomes - False - """ - # 1. is less than max length? - not_max_length_yet = cur_len < max_length - - # 2. can the new beams still improve? - # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. See the discussion - # below for more details. - # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 - # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of - # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there. - if early_stopping == "never" and length_penalty > 0.0: - best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty) - else: - best_running_score = running_scores[:, :1] / ( - tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty - ) - worst_finished_score = tf.where( - is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9 - ) - improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score) - - # 3. is there still a beam that has not finished? - still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & (early_stopping is True)) - - return not_max_length_yet & still_open_beam & improvement_still_possible - - def beam_search_body_fn( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - model_kwargs, - ): - """ - Beam Search iterative update function -- each iteration adds a new token and updates the best sequences - seen so far - """ - # 1. Forward current tokens - if model_kwargs.get("past_key_values") is None or needs_full_input: - input_ids = running_sequences[:, :, :cur_len] - else: - input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1) - model_inputs = self.prepare_inputs_for_generation( - flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs - ) - model_outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - logits = unflatten_beam_dim(model_outputs.logits[:, -1], num_beams) - - # 2. Compute log probs - # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and - # add new logprobs to existing running logprobs scores. - log_probs = tf.nn.log_softmax(logits) - log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len) - log_probs = unflatten_beam_dim(log_probs, num_beams) - if do_sample: - log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len) - log_probs = unflatten_beam_dim(log_probs, num_beams) - log_probs_processed = log_probs - log_probs = log_probs + tf.expand_dims(running_scores, axis=2) - vocab_size = log_probs.shape[2] - log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size)) - - # Store scores, attentions and hidden_states when required - if not use_xla and return_dict_in_generate: - if output_scores: - all_scores.append( - logits_warper( - flatten_beam_dim(running_sequences), - flatten_beam_dim(log_probs_processed), - cur_len, - ) - ) - if output_attentions and self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.decoder_attentions) - elif output_attentions and not self.config.is_encoder_decoder: - decoder_attentions.append(model_outputs.attentions) - if self.config.is_encoder_decoder: - cross_attentions.append(model_outputs.cross_attentions) - - if output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.decoder_hidden_states) - elif output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(model_outputs.hidden_states) - - # 3. Retrieve top-K - # Each item in batch has num_beams * vocab_size candidate sequences. For each item, get the top 2*k - # candidates with the highest log-probabilities. We gather the top 2*K beams here so that even if the - # best K sequences reach EOS simultaneously, we have another K sequences remaining to continue the live - # beam search. - # Gather the top 2*K scores from _all_ beams. - # Gather 2*k top beams. - # Recover the beam index by floor division. - # Recover token id by modulo division and expand Id array for broadcasting. - # Update sequences for the 2*K top-k new sequences. - beams_to_keep = 2 * num_beams - if do_sample: - topk_indices = sample_without_replacement(log_probs, beams_to_keep) - topk_log_probs = tf.gather(log_probs, topk_indices, axis=1, batch_dims=1) - else: - topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep) - topk_current_beam_indices = topk_indices // vocab_size - topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices) - topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices) - topk_ids = topk_indices % vocab_size - - # writes the new token - indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep]) - indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size]) - update_indices = tf.stack( - [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1 - ) - topk_sequences = tf.tensor_scatter_nd_update( - tensor=topk_running_sequences, - indices=update_indices, - updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]), - ) - - # we want to store the beam indices with batch information -> real beam index = beam index % num beams - batch_modified_indices = topk_current_beam_indices + tf.broadcast_to( - tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape - ) - update_indices = tf.stack( - [ - indices_batch, - indices_beam, - tf.broadcast_to(cur_len - decoder_prompt_len, [batch_size * beams_to_keep]), - ], - axis=-1, - ) - topk_beam_indices = tf.tensor_scatter_nd_update( - tensor=topk_running_beam_indices, - indices=update_indices, - updates=tf.reshape(batch_modified_indices, [batch_size * beams_to_keep]), - ) - - # 4. Check which sequences have ended - # Update current sequences: Did the top `num_beams` sequences reach an end marker? - # To prevent these just finished sequences from being added to the current sequences - # set of active beam search sequences, set their log probs to a very large negative value. - if eos_token_id is None: - eos_in_next_token = tf.zeros(topk_sequences[:, :, cur_len].shape, dtype=tf.bool) - else: - eos_in_next_token = tf.math.reduce_any( - tf.equal( - tf.broadcast_to( - topk_sequences[:, :, cur_len], - [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape, - ), - tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1), - ), - axis=0, - ) - did_topk_just_finished = eos_in_next_token & tf.broadcast_to( - tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0), - shape_list(eos_in_next_token), - ) - - # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next - # running sentences either - running_topk_log_probs = topk_log_probs + tf.cast(eos_in_next_token, tf.float32) * -1.0e9 - - # 5. Get running sequences scores for next - # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams - # (from top 2*k beams). - next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1] - next_running_sequences, next_running_scores, next_running_beam_indices = self._gather_beams( - [topk_sequences, running_topk_log_probs, topk_beam_indices], next_topk_indices - ) - - # 6. Process topk logits - # Further process log probs: - # - add length penalty - # - make sure no scores can be added anymore if beam is full - # - make sure still running sequences cannot be chosen as finalized beam - topk_log_probs = topk_log_probs / ( - tf.cast(cur_len + 1 - decoder_prompt_len, dtype=tf.float32) ** length_penalty - ) - beams_in_batch_are_full = tf.broadcast_to( - tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished) - ) & (early_stopping is True) - add_penalty = ~did_topk_just_finished | beams_in_batch_are_full - topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9 - - # 7. Get scores, sequences, is sentence finished for next. - # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores - # to existing finished scores and select the best from the new set of beams - merged_sequences = tf.concat([sequences, topk_sequences], axis=1) - merged_scores = tf.concat([scores, topk_log_probs], axis=1) - merged_beams = tf.concat([beam_indices, topk_beam_indices], axis=1) - merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1) - topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1] - next_sequences, next_scores, next_beam_indices, next_is_sent_finished = self._gather_beams( - [merged_sequences, merged_scores, merged_beams, merged_is_sent_finished], topk_merged_indices - ) - - # 8. Prepare data for the next iteration - # Determine the top k beam indices from the original set of all beams. With these, gather the top k - # beam-associated caches. - cur_len = cur_len + 1 - if "past_key_values" in model_outputs: - cache = tf.nest.map_structure( - lambda tensor: unflatten_beam_dim(tensor, num_beams, batch_axis=cache_batch_axis), - model_outputs.past_key_values, - ) - next_running_indices = self._gather_beams(topk_current_beam_indices, next_topk_indices) - next_cache = self._gather_beams(cache, next_running_indices, batch_axis=cache_batch_axis) - model_outputs["past_key_values"] = tf.nest.map_structure( - lambda tensor: flatten_beam_dim(tensor, batch_axis=cache_batch_axis), next_cache - ) - - if use_xla: - next_model_kwargs = self._update_model_kwargs_for_xla_generation( - model_outputs=model_outputs, - model_kwargs=model_kwargs, - cur_len=cur_len, - max_length=max_length, - batch_size=(batch_size * num_beams), - is_encoder_decoder=self.config.is_encoder_decoder, - batch_axis=cache_batch_axis, - ) - else: - next_model_kwargs = self._update_model_kwargs_for_generation( - model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - - # if we don't cache past_key_values key values we need the whole input - if model_kwargs.get("past_key_values", None) is None: - # let's throw out `past_key_values` since we don't want `None` tensors - model_kwargs.pop("past_key_values", None) - - return ( - cur_len, - next_running_sequences, - next_running_scores, - next_running_beam_indices, - next_sequences, - next_scores, - next_beam_indices, - next_is_sent_finished, - decoder_prompt_len, - next_model_kwargs, - ) - - # 5. run generation - # 1st generation step has to be run before to initialize `past_key_values` (if active) - ( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - model_kwargs, - ) = beam_search_body_fn( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - model_kwargs, - ) - - # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does - # NOT yield EOS token though) - maximum_iterations = max_length - cur_len - ( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - _, - ) = tf.while_loop( - beam_search_cond_fn, - beam_search_body_fn, - ( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - decoder_prompt_len, - model_kwargs, - ), - maximum_iterations=maximum_iterations, - ) - - # 6. prepare outputs - # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return - # running sequences for that batch item. - none_finished = tf.math.reduce_any(is_sent_finished, axis=1) - sequences = tf.where(none_finished[:, None, None], sequences, running_sequences) - beam_indices = tf.where(none_finished[:, None, None], beam_indices, running_beam_indices) - - # Apply the length penalty so that running scores match the finalized scores if they are used - running_scores = running_scores / (tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty) - scores = tf.where(none_finished[:, None], scores, running_scores) - - # Take best beams for each batch (the score is sorted in descending order) - sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :]) - scores = flatten_beam_dim(scores[:, :num_return_sequences]) - beam_indices = flatten_beam_dim(beam_indices[:, :num_return_sequences, :]) - - if not use_xla: - # Cut for backward compatibility - sequences = sequences[:, :cur_len] - beam_indices = beam_indices[:, : cur_len - decoder_prompt_len] - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - output_cls = TFBeamSampleEncoderDecoderOutput if do_sample else TFBeamSearchEncoderDecoderOutput - return output_cls( - sequences=sequences, - sequences_scores=scores, - scores=all_scores, - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - output_cls = TFBeamSampleDecoderOnlyOutput if do_sample else TFBeamSearchDecoderOnlyOutput - return output_cls( - sequences=sequences, - sequences_scores=scores, - scores=all_scores, - beam_indices=beam_indices, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequences - - def contrastive_search( - self, - input_ids: tf.Tensor, - top_k: Optional[int] = 1, - penalty_alpha: Optional[float] = 0, - logits_processor: Optional[TFLogitsProcessorList] = None, - logits_warper: Optional[TFLogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - **model_kwargs, - ) -> Union[TFContrastiveSearchOutput, tf.Tensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - top_k (`int`, *optional*, defaults to 1): - The size of the candidate set that is used to re-rank for contrastive search - penalty_alpha (`float`, *optional*, defaults to 0): - The degeneration penalty for contrastive search; activate when it is larger than 0 - logits_processor (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`TFLogitsProcessorList`, *optional*): - An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`] - used to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, list[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `call` function of the model. If - model is an encoder-decoder model the kwargs should include `encoder_outputs`. - Return: - [`~generation.TFContrastiveSearchDecoderOnlyOutput`], - [`~generation.TFContrastiveSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the - generated tokens (default behaviour) or a [`~generation.TFContrastiveySearchDecoderOnlyOutput`] if - `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a - [`~generation.TFContrastiveSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - Examples: - ```python - >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - >>> model = TFAutoModelForCausalLM.from_pretrained("facebook/opt-125m") - >>> # set pad_token_id to eos_token_id because OPT does not have a PAD token - >>> model.config.pad_token_id = model.config.eos_token_id - >>> input_prompt = "DeepMind Company is" - >>> input_ids = tokenizer(input_prompt, return_tensors="tf") - >>> outputs = model.contrastive_search(**input_ids, penalty_alpha=0.6, top_k=4, max_length=64) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] - ```""" - - def gather_best_candidate(nested, selected_idx_stacked, batch_axis=0): - """Gathers the slices indexed by selected_idx_stacked from a potentially nested structure of tensors.""" - - def gather_fn(tensor): - gathered_tensor = tf.gather(params=tensor, indices=selected_idx_stacked, axis=batch_axis) - return gathered_tensor - - return tf.nest.map_structure(gather_fn, nested) - - # 1. init greedy_search values - logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() - max_length = max_length if max_length is not None else self.generation_config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - use_cache = True # In contrastive search, we always use cache - model_kwargs.pop("use_cache", None) - - use_xla = not tf.executing_eagerly() - # TODO (Joao): fix cache format or find programmatic way to detect cache index - # GPT2 and other models has a slightly different cache structure, with a different batch axis - model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) - cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0 - - # 2. init `attentions`, `hidden_states`, and `scores` tuples - scores = [] if (return_dict_in_generate and output_scores) else None - decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None - cross_attentions = [] if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - - # 3. init tensors to use for "xla-compileable" generate function - batch_size, cur_len = shape_list(input_ids) - - # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` - input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) - generated = tf.concat([input_ids, input_ids_padding], axis=-1) - finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) - - # 4. define "xla-compile-able" stop-condition and auto-regressive function - # define condition fn - def contrastive_search_cond_fn( - generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables - ): - """state termination condition fn.""" - return ~tf.reduce_all(finished_sequences) - - # define condition fn - def contrastive_search_body_fn( - generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables - ): - """state update fn.""" - - # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; - # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step - if model_kwargs.get("past_key_values") is None: - # prepare inputs - model_inputs = self.prepare_inputs_for_generation( - generated[:, :cur_len], use_cache=use_cache, **model_kwargs - ) - - # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save - # the `encoder_outputs` - outputs = self( - **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - - # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with - # previous tokens) - if self.config.is_encoder_decoder: - last_hidden_states = outputs.decoder_hidden_states[-1] - else: - last_hidden_states = outputs.hidden_states[-1] - - # XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across - # iterations (with fixed shapes) - if use_xla: - last_hidden_states = tf.pad(last_hidden_states, [[0, 0], [0, max_length - cur_len], [0, 0]]) - - # next logit for contrastive search to select top-k candidate tokens - logit_for_next_step = outputs.logits[:, -1, :] - - if use_xla: - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_outputs=outputs, - model_kwargs=model_kwargs, - cur_len=cur_len, - max_length=max_length, - batch_size=batch_size, - is_encoder_decoder=self.config.is_encoder_decoder, - batch_axis=cache_batch_axis, - ) - else: - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - - # Expands model inputs top_k times, for batched forward passes (akin to beam search). - _, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - - past_key_values = model_kwargs.get("past_key_values") - if past_key_values is None: - raise ValueError( - f"{self.__class__.__name__} does not support caching and therefore **can't** be used " - "for contrastive search." - ) - elif ( - not isinstance(past_key_values[0], (tuple, tf.Tensor)) - or past_key_values[0][0].shape[0] != batch_size - ): - raise ValueError( - f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " - "used for contrastive search without further modifications." - ) - else: - logit_for_next_step = next_step_cached_variables["logit_for_next_step"] - last_hidden_states = next_step_cached_variables["last_hidden_states"] - outputs = next_step_cached_variables["outputs"] - - # contrastive_search main logic start: - # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by - # degeneration penalty - - logit_for_next_step = logits_processor(generated, logit_for_next_step, cur_len) - logit_for_next_step = logits_warper(generated, logit_for_next_step, cur_len) - next_probs = stable_softmax(logit_for_next_step, axis=-1) - top_k_probs, top_k_ids = tf.math.top_k(next_probs, k=top_k) - - # Store scores, attentions and hidden_states when required - if not use_xla and return_dict_in_generate: - if output_scores: - scores.append(logit_for_next_step) - if output_attentions and self.config.is_encoder_decoder: - decoder_attentions.append(outputs.decoder_attentions) - elif output_attentions and not self.config.is_encoder_decoder: - decoder_attentions.append(outputs.attentions) - if self.config.is_encoder_decoder: - cross_attentions.append(outputs.cross_attentions) - - if output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(outputs.decoder_hidden_states) - elif output_hidden_states and self.config.is_encoder_decoder: - decoder_hidden_states.append(outputs.hidden_states) - - # Replicates the new past_key_values to match the `top_k` candidates - model_kwargs["past_key_values"] = tf.nest.map_structure( - lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past_key_values"] - ) - - # compute the candidate tokens by the language model and collects their hidden_states - next_model_inputs = self.prepare_inputs_for_generation( - tf.reshape(top_k_ids, [-1, 1]), use_cache=use_cache, **model_kwargs - ) - outputs = self( - **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - next_past_key_values = self._extract_past_from_model_output(outputs) - - logits = outputs.logits[:, -1, :] - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - context_hidden = tf.repeat(last_hidden_states[:, :cur_len, :], top_k, axis=0) - - # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the - # model confidence - selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) - - # converts indices to a dimension of top_k to the stacked top_k * batch_size dimension, for indexing - # without a need to reshape on tensors that have these two dimensions stacked - selected_idx_stacked = selected_idx + tf.range(selected_idx.shape[0], dtype=tf.int64) * top_k - - # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing - # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores - # (model confidence minus degeneration penalty); (6) decoder hidden_states - next_tokens = tf.gather(top_k_ids, selected_idx, axis=1, batch_dims=1) - next_hidden = gather_best_candidate(next_hidden, selected_idx_stacked) - - # XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across - # iterations (with fixed shapes) - if use_xla: - last_hidden_states = dynamic_update_slice(last_hidden_states, next_hidden, [0, cur_len, 0]) - else: - last_hidden_states = tf.concat([last_hidden_states, next_hidden], axis=1) - - next_decoder_hidden_states = gather_best_candidate(full_hidden_states, selected_idx_stacked) - next_past_key_values = gather_best_candidate( - next_past_key_values, selected_idx_stacked, batch_axis=cache_batch_axis - ) - logit_for_next_step = gather_best_candidate(logits, selected_idx_stacked) - - # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration - if self.config.is_encoder_decoder: - next_step_cross_attentions = () - next_step_decoder_attentions = () - if output_attentions: - next_step_cross_attentions = gather_best_candidate(outputs.cross_attentions, selected_idx_stacked) - next_step_decoder_attentions = gather_best_candidate( - outputs.decoder_attentions, selected_idx_stacked - ) - outputs = TFSeq2SeqLMOutput( - past_key_values=next_past_key_values, - decoder_hidden_states=next_decoder_hidden_states, - decoder_attentions=next_step_decoder_attentions or None, - cross_attentions=next_step_cross_attentions or None, - ) - else: - next_step_attentions = () - if output_attentions: - next_step_attentions = gather_best_candidate(outputs.attentions, selected_idx_stacked) - outputs = TFCausalLMOutputWithPast( - past_key_values=next_past_key_values, - hidden_states=next_decoder_hidden_states, - attentions=next_step_attentions or None, - ) - # contrastive_search main logic end - - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) - next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) - next_token_is_eos = tf.math.reduce_any( - tf.equal( - tf.broadcast_to(next_tokens, (len(eos_token_id), batch_size)), tf.expand_dims(eos_token_id, -1) - ), - axis=0, - ) - finished_sequences = finished_sequences | next_token_is_eos - - # update `generated` and `cur_len` - update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) - generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) - cur_len += 1 - - if use_xla: - # NOTE: 1) relative to other generation strategies, contrastive search is always running forward - # passes one step ahead -- hence the `cur_len=cur_len + 1`; 2) the attention mask here is expanded from - # [batch_size, ...] to [batch_size*top_k, ...] -- hence the `batch_size=batch_size * top_k` - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_outputs=outputs, - model_kwargs=model_kwargs, - cur_len=cur_len + 1, - max_length=max_length, - batch_size=batch_size * top_k, - is_encoder_decoder=self.config.is_encoder_decoder, - batch_axis=cache_batch_axis, - ) - else: - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - - next_step_cached_variables = { - "logit_for_next_step": logit_for_next_step, - "last_hidden_states": last_hidden_states, - "outputs": outputs, - } - return generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables - - # 5. run generation - # 1st generation step has to be run before to initialize `past_key_values` - generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables = contrastive_search_body_fn( - generated, finished_sequences, cur_len, model_kwargs, None - ) - - # 2-to-n generation steps can then be run in autoregressive fashion - # only in case 1st generation step does NOT yield EOS token though - maximum_iterations = max_length - cur_len - generated, _, cur_len, _, _ = tf.while_loop( - contrastive_search_cond_fn, - contrastive_search_body_fn, - (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables), - maximum_iterations=maximum_iterations, - ) - - # 6. prepare outputs - if not use_xla: - # cut for backward compatibility - generated = generated[:, :cur_len] - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - # if model is an encoder-decoder, retrieve encoder attention weights - # and hidden states - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - scores = tuple(scores) if scores is not None else None - decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None - cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None - decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None - - return TFContrastiveSearchEncoderDecoderOutput( - sequences=generated, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return TFContrastiveSearchDecoderOnlyOutput( - sequences=generated, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return generated - - -def scatter_values_on_batch_indices(values, batch_indices): - shape = shape_list(batch_indices) - # broadcast batch dim to shape - broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1]) - # transform batch_indices to pair_indices - pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0)) - # scatter values to pair indices - return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape) - - -def sample_without_replacement(logits, num_samples): - """ - categorical sampling without replacement is currently not implemented the gumbel-max trick will do for now see - https://github.com/tensorflow/tensorflow/issues/9260 for more info - """ - z = -tf.math.log(-tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))) - _, indices = tf.nn.top_k(logits + z, num_samples) - return indices - - -def _ranking_fast( - context_hidden: tf.Tensor, - next_hidden: tf.Tensor, - next_top_k_probs: tf.Tensor, - alpha: float, - beam_width: int, -) -> tf.Tensor: - """ - Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described - in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each - row in the batch. - """ - norm_context_hidden = context_hidden / tf.norm(context_hidden, axis=2, keepdims=True) - norm_next_hidden = next_hidden / tf.norm(next_hidden, axis=2, keepdims=True) - cosine_matrix = tf.squeeze(tf.linalg.matmul(norm_context_hidden, norm_next_hidden, transpose_b=True), axis=-1) - degeneration_penalty = tf.reduce_max(cosine_matrix, axis=-1) - next_top_k_probs = tf.reshape(next_top_k_probs, shape=[-1]) - contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty - contrastive_score = tf.reshape(contrastive_score, shape=[-1, beam_width]) - selected_idx = tf.argmax(contrastive_score, axis=1) - return selected_idx diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py deleted file mode 100644 index 5a25a6059a25..000000000000 --- a/src/transformers/modeling_flax_outputs.py +++ /dev/null @@ -1,700 +0,0 @@ -# Copyright 2021 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import flax -import jax.numpy as jnp - -from .utils import ModelOutput - - -@flax.struct.dataclass -class FlaxBaseModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithNoAttention(ModelOutput): - """ - Base class for model's outputs, with potential hidden states. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one - for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the - model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a pooling operation on the spatial dimensions. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one - for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the - model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - pooler_output: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxImageClassifierOutputWithNoAttention(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when - `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one - for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also - called feature maps) of the model at the output of each stage. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - past_key_values (`dict[str, jnp.ndarray]`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - past_key_values: Optional[dict[str, jnp.ndarray]] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithPooling(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) further processed by a - Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence - prediction (classification) objective during pretraining. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - pooler_output: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one - for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - pooler_output: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxSeq2SeqModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - last_hidden_state: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - decoder_attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - encoder_last_hidden_state: Optional[jnp.ndarray] = None - encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - encoder_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxCausalLMOutputWithCrossAttentions(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Cross attentions weights after the attention softmax, used to compute the weighted average in the - cross-attention heads. - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value - states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. - Only relevant if `config.is_decoder = True`. - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - logits: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxMaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -FlaxCausalLMOutput = FlaxMaskedLMOutput - - -@flax.struct.dataclass -class FlaxSeq2SeqLMOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - logits: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - decoder_attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - encoder_last_hidden_state: Optional[jnp.ndarray] = None - encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - encoder_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxNextSentencePredictorOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence sentence classification models. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - logits: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - decoder_attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - encoder_last_hidden_state: Optional[jnp.ndarray] = None - encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - encoder_attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxMultipleChoiceModelOutput(ModelOutput): - """ - Base class for outputs of multiple choice models. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxTokenClassifierOutput(ModelOutput): - """ - Base class for outputs of token classification models. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering models. - - Args: - start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - start_logits: Optional[jnp.ndarray] = None - end_logits: Optional[jnp.ndarray] = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence question answering models. - - Args: - start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - start_logits: Optional[jnp.ndarray] = None - end_logits: Optional[jnp.ndarray] = None - past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None - decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - decoder_attentions: Optional[tuple[jnp.ndarray]] = None - cross_attentions: Optional[tuple[jnp.ndarray]] = None - encoder_last_hidden_state: Optional[jnp.ndarray] = None - encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None - encoder_attentions: Optional[tuple[jnp.ndarray]] = None diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py deleted file mode 100644 index dece5233d956..000000000000 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ /dev/null @@ -1,491 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch - Flax general utilities.""" - -import os -from pickle import UnpicklingError - -import jax -import jax.numpy as jnp -import numpy as np -from flax.serialization import from_bytes -from flax.traverse_util import flatten_dict, unflatten_dict - -import transformers - -from . import is_safetensors_available, is_torch_available -from .utils import check_torch_load_is_safe, logging - - -if is_torch_available(): - import torch - -if is_safetensors_available(): - from safetensors import safe_open - from safetensors.flax import load_file as safe_load_file - - -logger = logging.get_logger(__name__) - - -##################### -# PyTorch => Flax # -##################### - - -def load_pytorch_checkpoint_in_flax_state_dict( - flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False -): - """Load pytorch checkpoints in a flax model""" - - if not is_sharded: - pt_path = os.path.abspath(pytorch_checkpoint_path) - logger.info(f"Loading PyTorch weights from {pt_path}") - - if pt_path.endswith(".safetensors"): - pt_state_dict = {} - with safe_open(pt_path, framework="flax") as f: - for k in f.keys(): - pt_state_dict[k] = f.get_tensor(k) - else: - try: - import torch # noqa: F401 - except (ImportError, ModuleNotFoundError): - logger.error( - "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation" - " instructions." - ) - raise - - check_torch_load_is_safe() - pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) - logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") - - flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) - else: - # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files - flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model) - return flax_state_dict - - -def rename_key_and_reshape_tensor( - pt_tuple_key: tuple[str], - pt_tensor: np.ndarray, - random_flax_state_dict: dict[str, jnp.ndarray], - model_prefix: str, -) -> tuple[tuple[str], np.ndarray]: - """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" - - def is_key_or_prefix_key_in_dict(key: tuple[str]) -> bool: - """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict""" - return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0 - - # layer norm - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): - return renamed_pt_tuple_key, pt_tensor - - # batch norm layer mean - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",) - if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key): - return renamed_pt_tuple_key, pt_tensor - - # batch norm layer var - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",) - if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key): - return renamed_pt_tuple_key, pt_tensor - - # embedding - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) - if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): - return renamed_pt_tuple_key, pt_tensor - - # conv layer - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key): - pt_tensor = pt_tensor.transpose(2, 3, 1, 0) - return renamed_pt_tuple_key, pt_tensor - - # linear layer - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key): - pt_tensor = pt_tensor.T - return renamed_pt_tuple_key, pt_tensor - - # old PyTorch layer norm weight - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) - if pt_tuple_key[-1] == "gamma": - return renamed_pt_tuple_key, pt_tensor - - # old PyTorch layer norm bias - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) - if pt_tuple_key[-1] == "beta": - return renamed_pt_tuple_key, pt_tensor - - # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 - name = None - if pt_tuple_key[-3::2] == ("parametrizations", "original0"): - name = pt_tuple_key[-2] + "_g" - elif pt_tuple_key[-3::2] == ("parametrizations", "original1"): - name = pt_tuple_key[-2] + "_v" - if name is not None: - renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,) - return renamed_pt_tuple_key, pt_tensor - - return pt_tuple_key, pt_tensor - - -def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): - # convert pytorch tensor to numpy - from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor) - bfloat16 = torch.bfloat16 if from_bin else "bfloat16" - - weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} - - if from_bin: - for k, v in pt_state_dict.items(): - # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision - if v.dtype == bfloat16: - v = v.float() - pt_state_dict[k] = v.cpu().numpy() - - model_prefix = flax_model.base_model_prefix - - # use params dict if the model contains batch norm layers - if "params" in flax_model.params: - flax_model_params = flax_model.params["params"] - else: - flax_model_params = flax_model.params - random_flax_state_dict = flatten_dict(flax_model_params) - - # add batch_stats keys,values to dict - if "batch_stats" in flax_model.params: - flax_batch_stats = flatten_dict(flax_model.params["batch_stats"]) - random_flax_state_dict.update(flax_batch_stats) - - flax_state_dict = {} - - load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( - model_prefix in {k.split(".")[0] for k in pt_state_dict} - ) - load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( - model_prefix not in {k.split(".")[0] for k in pt_state_dict} - ) - - # Need to change some parameters name to match Flax names - for pt_key, pt_tensor in pt_state_dict.items(): - pt_tuple_key = tuple(pt_key.split(".")) - is_bfloat_16 = weight_dtypes[pt_key] == bfloat16 - - # remove base model prefix if necessary - has_base_model_prefix = pt_tuple_key[0] == model_prefix - if load_model_with_head_into_base_model and has_base_model_prefix: - pt_tuple_key = pt_tuple_key[1:] - - # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor( - pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix - ) - - # add model prefix if necessary - require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict - if load_base_model_into_model_with_head and require_base_model_prefix: - flax_key = (model_prefix,) + flax_key - - if flax_key in random_flax_state_dict: - if flax_tensor.shape != random_flax_state_dict[flax_key].shape: - raise ValueError( - f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " - f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." - ) - - # add batch stats if the model contains batchnorm layers - if "batch_stats" in flax_model.params: - if "mean" in flax_key[-1] or "var" in flax_key[-1]: - flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) - continue - # remove num_batches_tracked key - if "num_batches_tracked" in flax_key[-1]: - flax_state_dict.pop(flax_key, None) - continue - - # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = ( - jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) - ) - else: - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = ( - jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) - ) - - return unflatten_dict(flax_state_dict) - - -############################ -# Sharded Pytorch => Flax # -############################ - - -def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): - import torch - - # Load the index - flax_state_dict = {} - for shard_file in shard_filenames: - # load using msgpack utils - check_torch_load_is_safe() - pt_state_dict = torch.load(shard_file, weights_only=True) - weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} - pt_state_dict = { - k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() - } - - model_prefix = flax_model.base_model_prefix - - # use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict - if "batch_stats" in flax_model.params: - flax_model_params = flax_model.params["params"] - - random_flax_state_dict = flatten_dict(flax_model_params) - random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"])) - else: - flax_model_params = flax_model.params - random_flax_state_dict = flatten_dict(flax_model_params) - - load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( - model_prefix in {k.split(".")[0] for k in pt_state_dict} - ) - load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( - model_prefix not in {k.split(".")[0] for k in pt_state_dict} - ) - # Need to change some parameters name to match Flax names - for pt_key, pt_tensor in pt_state_dict.items(): - pt_tuple_key = tuple(pt_key.split(".")) - is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16 - - # remove base model prefix if necessary - has_base_model_prefix = pt_tuple_key[0] == model_prefix - if load_model_with_head_into_base_model and has_base_model_prefix: - pt_tuple_key = pt_tuple_key[1:] - - # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor( - pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix - ) - # add model prefix if necessary - require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict - if load_base_model_into_model_with_head and require_base_model_prefix: - flax_key = (model_prefix,) + flax_key - - if flax_key in random_flax_state_dict: - if flax_tensor.shape != random_flax_state_dict[flax_key].shape: - raise ValueError( - f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " - f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." - ) - - # add batch stats if the model contains batchnorm layers - if "batch_stats" in flax_model.params: - if "mean" in flax_key[-1]: - flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) - continue - if "var" in flax_key[-1]: - flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) - continue - # remove num_batches_tracked key - if "num_batches_tracked" in flax_key[-1]: - flax_state_dict.pop(flax_key, None) - continue - - # also add unexpected weight so that warning is thrown - flax_state_dict[("params",) + flax_key] = ( - jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) - ) - - else: - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = ( - jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) - ) - return unflatten_dict(flax_state_dict) - - -##################### -# Flax => PyTorch # -##################### - - -def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path): - """Load flax checkpoints in a PyTorch model""" - flax_checkpoint_path = os.path.abspath(flax_checkpoint_path) - logger.info(f"Loading Flax weights from {flax_checkpoint_path}") - - # import correct flax class - flax_cls = getattr(transformers, "Flax" + model.__class__.__name__) - - # load flax weight dict - if flax_checkpoint_path.endswith(".safetensors"): - flax_state_dict = safe_load_file(flax_checkpoint_path) - flax_state_dict = unflatten_dict(flax_state_dict, sep=".") - else: - with open(flax_checkpoint_path, "rb") as state_f: - try: - flax_state_dict = from_bytes(flax_cls, state_f.read()) - except UnpicklingError: - raise OSError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ") - - return load_flax_weights_in_pytorch_model(model, flax_state_dict) - - -def load_flax_weights_in_pytorch_model(pt_model, flax_state): - """Load flax checkpoints in a PyTorch model""" - - try: - import torch # noqa: F401 - except (ImportError, ModuleNotFoundError): - logger.error( - "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation" - " instructions." - ) - raise - - # check if we have bf16 weights - is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() - if any(is_type_bf16): - # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 - # and bf16 is not fully supported in PT yet. - logger.warning( - "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " - "before loading those in PyTorch model." - ) - flax_state = jax.tree_util.tree_map( - lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state - ) - - flax_state_dict = flatten_dict(flax_state) - pt_model_dict = pt_model.state_dict() - - load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( - pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict} - ) - load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( - pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict} - ) - - # keep track of unexpected & missing keys - unexpected_keys = [] - missing_keys = set(pt_model_dict.keys()) - - for flax_key_tuple, flax_tensor in flax_state_dict.items(): - has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix - require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict - - # adapt flax_key to prepare for loading from/to base model only - if load_model_with_head_into_base_model and has_base_model_prefix: - flax_key_tuple = flax_key_tuple[1:] - elif load_base_model_into_model_with_head and require_base_model_prefix: - flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple - - # rename flax weights to PyTorch format - if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict: - # conv layer - flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) - elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict: - # linear layer - flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_tensor = flax_tensor.T - elif flax_key_tuple[-1] in ["scale", "embedding"]: - flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - - # adding batch stats from flax batch norm to pt - elif "mean" in flax_key_tuple[-1]: - flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) - elif "var" in flax_key_tuple[-1]: - flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) - - if "batch_stats" in flax_state: - flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header - else: - flax_key = ".".join(flax_key_tuple) - - # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation. - special_pt_names = {} - # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 - for key in pt_model_dict: - key_components = key.split(".") - name = None - if key_components[-3::2] == ["parametrizations", "original0"]: - name = key_components[-2] + "_g" - elif key_components[-3::2] == ["parametrizations", "original1"]: - name = key_components[-2] + "_v" - if name is not None: - key_components = key_components[:-3] + [name] - key_to_check = ".".join(key_components) - special_pt_names[key_to_check] = key - - if flax_key in special_pt_names: - flax_key = special_pt_names[flax_key] - - if flax_key in pt_model_dict: - if flax_tensor.shape != pt_model_dict[flax_key].shape: - raise ValueError( - f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " - f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." - ) - else: - # add weight to pytorch dict - flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor - pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) - # remove from missing keys - missing_keys.remove(flax_key) - else: - # weight is not expected by PyTorch model - unexpected_keys.append(flax_key) - - pt_model.load_state_dict(pt_model_dict) - - # re-transform missing_keys to list - missing_keys = list(missing_keys) - - if len(unexpected_keys) > 0: - logger.warning( - "Some weights of the Flax model were not used when initializing the PyTorch model" - f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" - f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" - " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" - f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" - " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" - " FlaxBertForSequenceClassification model)." - ) - else: - logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" - f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" - " use it for predictions and inference." - ) - else: - logger.warning( - f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n" - "If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {pt_model.__class__.__name__} for predictions without further training." - ) - - return pt_model diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py deleted file mode 100644 index bc9a4d473f36..000000000000 --- a/src/transformers/modeling_flax_utils.py +++ /dev/null @@ -1,1274 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import gc -import json -import os -import warnings -from functools import partial -from pickle import UnpicklingError -from typing import Any, Optional, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -import msgpack.exceptions -from flax.core.frozen_dict import FrozenDict, unfreeze -from flax.serialization import from_bytes, to_bytes -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.random import PRNGKey - -from .configuration_utils import PretrainedConfig -from .dynamic_module_utils import custom_object_save -from .generation import FlaxGenerationMixin, GenerationConfig -from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict -from .utils import ( - FLAX_WEIGHTS_INDEX_NAME, - FLAX_WEIGHTS_NAME, - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - WEIGHTS_INDEX_NAME, - WEIGHTS_NAME, - PushToHubMixin, - add_code_sample_docstrings, - add_start_docstrings_to_model_forward, - cached_file, - copy_func, - download_url, - has_file, - is_offline_mode, - is_remote_url, - logging, - replace_return_docstrings, -) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files -from .utils.import_utils import is_safetensors_available - - -if is_safetensors_available(): - from safetensors import safe_open - from safetensors.flax import load_file as safe_load_file - from safetensors.flax import save_file as safe_save_file - -logger = logging.get_logger(__name__) - - -def quick_gelu(x): - return x * jax.nn.sigmoid(1.702 * x) - - -ACT2FN = { - "gelu": partial(nn.gelu, approximate=False), - "relu": nn.relu, - "silu": nn.swish, - "swish": nn.swish, - "gelu_new": partial(nn.gelu, approximate=True), - "quick_gelu": quick_gelu, - "gelu_pytorch_tanh": partial(nn.gelu, approximate=True), - "tanh": nn.tanh, -} - - -def flax_shard_checkpoint(params, max_shard_size="10GB"): - """ - Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a - given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so - there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For - example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as - [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. - - - - If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will - have a size greater than `max_shard_size`. - - - - Args: - params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): - The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit - (like `"5MB"`). - """ - max_shard_size = convert_file_size_to_int(max_shard_size) - - sharded_state_dicts = [] - current_block = {} - current_block_size = 0 - total_size = 0 - - # flatten the weights to chunk - weights = flatten_dict(params, sep="/") - for item in weights: - weight_size = weights[item].size * weights[item].dtype.itemsize - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) - current_block = {} - current_block_size = 0 - - current_block[item] = weights[item] - current_block_size += weight_size - total_size += weight_size - - # Add the last block - sharded_state_dicts.append(current_block) - - # If we only have one shard, we return it - if len(sharded_state_dicts) == 1: - return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None - - # Otherwise, let's build the index - weight_map = {} - shards = {} - for idx, shard in enumerate(sharded_state_dicts): - shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.msgpack") - shards[shard_file] = shard - for weight_name in shard: - weight_map[weight_name] = shard_file - - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - return shards, index - - -class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): - r""" - Base class for all models. - - [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, - downloading and saving models. - - Class attributes (overridden by derived classes): - - - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class - for this model architecture. - - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived - classes of the same architecture adding modules on top of the base model. - - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP - models, `pixel_values` for vision models and `input_values` for speech models). - """ - - config_class = None - base_model_prefix = "" - main_input_name = "input_ids" - _auto_class = None - _missing_keys = set() - - def __init__( - self, - config: PretrainedConfig, - module: nn.Module, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - ): - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) - if config is None: - raise ValueError("config cannot be None") - - if module is None: - raise ValueError("module cannot be None") - - # Those are private to be exposed as typed property on derived classes. - self._config = config - self._module = module - - # Those are public as their type is generic to every derived classes. - self.key = PRNGKey(seed) - self.dtype = dtype - self.input_shape = input_shape - self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None - - # To check if the model was initialized automatically. - self._is_initialized = _do_init - - if _do_init: - # randomly initialized parameters - random_params = self.init_weights(self.key, input_shape) - params_shape_tree = jax.eval_shape(lambda params: params, random_params) - else: - init_fn = partial(self.init_weights, input_shape=input_shape) - params_shape_tree = jax.eval_shape(init_fn, self.key) - - logger.info( - "Model weights are not initialized as `_do_init` is set to `False`. " - f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights." - ) - - # get the shape of the parameters - self._params_shape_tree = params_shape_tree - - # save required_params as set - self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - - # initialize the parameters - if _do_init: - self.params = random_params - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> dict: - raise NotImplementedError(f"init method has to be implemented for {self}") - - def enable_gradient_checkpointing(self): - raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}") - - @classmethod - def _from_config(cls, config, **kwargs): - """ - All context managers that the model should be initialized under go here. - """ - return cls(config, **kwargs) - - @property - def framework(self) -> str: - """ - :str: Identifies that this is a Flax model. - """ - return "flax" - - @property - def config(self) -> PretrainedConfig: - return self._config - - @property - def module(self) -> nn.Module: - return self._module - - @property - def params(self) -> Union[dict, FrozenDict]: - if not self._is_initialized: - raise ValueError( - "`params` cannot be accessed from model when the model is created with `_do_init=False`. " - "You must call `init_weights` manually and store the params outside of the model and " - "pass it explicitly where needed." - ) - return self._params - - @property - def required_params(self) -> set: - return self._required_params - - @property - def params_shape_tree(self) -> dict: - return self._params_shape_tree - - @params.setter - def params(self, params: Union[dict, FrozenDict]): - # don't set params if the model is not initialized - if not self._is_initialized: - raise ValueError( - "`params` cannot be set from model when the model is created with `_do_init=False`. " - "You store the params outside of the model." - ) - - if isinstance(params, FrozenDict): - params = unfreeze(params) - param_keys = set(flatten_dict(params).keys()) - if len(self.required_params - param_keys) > 0: - raise ValueError( - "Some parameters are missing. Make sure that `params` include the following " - f"parameters {self.required_params - param_keys}" - ) - self._params = params - - def _cast_floating_to(self, params: Union[dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: - """ - Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. - """ - - # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 - def conditional_cast(param): - if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): - param = param.astype(dtype) - return param - - if mask is None: - return jax.tree_util.tree_map(conditional_cast, params) - - flat_params = flatten_dict(params) - flat_mask, _ = jax.tree_util.tree_flatten(mask) - - for masked, key in zip(flat_mask, sorted(flat_params.keys())): - if masked: - flat_params[key] = conditional_cast(flat_params[key]) - - return unflatten_dict(flat_params) - - def to_bf16(self, params: Union[dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast - the `params` in place. - - This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full - half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip. - - Examples: - - ```python - >>> from transformers import FlaxBertModel - - >>> # load model - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision - >>> model.params = model.to_bf16(model.params) - >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) - >>> # then pass the mask as follows - >>> from flax import traverse_util - - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) - >>> mask = { - ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) - ... for path in flat_params - ... } - >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_bf16(model.params, mask) - ```""" - return self._cast_floating_to(params, jnp.bfloat16, mask) - - def to_fp32(self, params: Union[dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the - model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip - - Examples: - - ```python - >>> from transformers import FlaxBertModel - - >>> # Download model and configuration from huggingface.co - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> # By default, the model params will be in fp32, to illustrate the use of this method, - >>> # we'll first cast to fp16 and back to fp32 - >>> model.params = model.to_f16(model.params) - >>> # now cast back to fp32 - >>> model.params = model.to_fp32(model.params) - ```""" - return self._cast_floating_to(params, jnp.float32, mask) - - def to_fp16(self, params: Union[dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the - `params` in place. - - This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full - half-precision training or to save weights in float16 for inference in order to save memory and improve speed. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip - - Examples: - - ```python - >>> from transformers import FlaxBertModel - - >>> # load model - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> # By default, the model params will be in fp32, to cast these to float16 - >>> model.params = model.to_fp16(model.params) - >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) - >>> # then pass the mask as follows - >>> from flax import traverse_util - - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) - >>> mask = { - ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) - ... for path in flat_params - ... } - >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_fp16(model.params, mask) - ```""" - return self._cast_floating_to(params, jnp.float16, mask) - - @classmethod - def load_flax_weights(cls, resolved_archive_file): - try: - if resolved_archive_file.endswith(".safetensors"): - state = safe_load_file(resolved_archive_file) - state = unflatten_dict(state, sep=".") - else: - with open(resolved_archive_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: - try: - with open(resolved_archive_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise OSError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ") - - return state - - @classmethod - def load_flax_sharded_weights(cls, shard_files): - """ - This is the same as [`flax.serialization.from_bytes`] - (https:lax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint. - - This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being - loaded in the model. - - Args: - shard_files (`list[str]`: - The list of shard files to load. - - Returns: - `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model': - {'params': {'...'}}}`. - """ - - # Load the index - state_sharded_dict = {} - - for shard_file in shard_files: - # load using msgpack utils - try: - with open(shard_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: - with open(shard_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise OSError(f"Unable to convert {shard_file} to Flax deserializable object. ") - - state = flatten_dict(state, sep="/") - state_sharded_dict.update(state) - del state - gc.collect() - - # the state dict is unflattened to the match the format of model.params - return unflatten_dict(state_sharded_dict, sep="/") - - @classmethod - def can_generate(cls) -> bool: - """ - Returns whether this model can generate sequences with `.generate()`. Returns: - `bool`: Whether this model can generate sequences with `.generate()`. - """ - # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. - # Alternatively, the model can also have a custom `generate` function. - if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): - return False - return True - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - dtype: jnp.dtype = jnp.float32, - *model_args, - config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - revision: str = "main", - **kwargs, - ): - r""" - Instantiate a pretrained flax model from a pre-trained model configuration. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, - `from_pt` should be set to `True`. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. - model_args (sequence of positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): - Can be either: - - - an instance of a class derived from [`PretrainedConfig`], - - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. - - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - - - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can - specify the folder name here. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import BertConfig, FlaxBertModel - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") - >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). - >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") - >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). - >>> config = BertConfig.from_json_file("./pt_model/config.json") - >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) - ```""" - from_pt = kwargs.pop("from_pt", False) - resume_download = kwargs.pop("resume_download", None) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - trust_remote_code = kwargs.pop("trust_remote_code", None) - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - _do_init = kwargs.pop("_do_init", True) - subfolder = kwargs.pop("subfolder", "") - commit_hash = kwargs.pop("_commit_hash", None) - - # Not relevant for Flax Models - _ = kwargs.pop("adapter_kwargs", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if trust_remote_code is True: - logger.warning( - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" - " ignored." - ) - - user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} - if from_pipeline is not None: - user_agent["using_pipeline"] = from_pipeline - - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - # Load config if we don't provide a configuration - if not isinstance(config, PretrainedConfig): - config_path = config if config is not None else pretrained_model_name_or_path - config, model_kwargs = cls.config_class.from_pretrained( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - _from_auto=from_auto_class, - _from_pipeline=from_pipeline, - _commit_hash=commit_hash, - **kwargs, - ) - else: - model_kwargs = kwargs.copy() - - if commit_hash is None: - commit_hash = getattr(config, "_commit_hash", None) - - # Add the dtype to model_kwargs - model_kwargs["dtype"] = dtype - - # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the - # index of the files. - is_sharded = False - - # Load model - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): - # Load from a Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): - # Load from a sharded Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) - is_sharded = True - elif is_safetensors_available() and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) - elif is_safetensors_available() and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) - elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): - # Load from a PyTorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) - elif from_pt and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) - ): - # Load from a sharded pytorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) - is_sharded = True - # At this stage we don't have a weight file so we will raise an error. - elif is_safetensors_available() and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - is_sharded = True - raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): - raise OSError( - f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " - "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " - "weights." - ) - else: - raise OSError( - f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " - f"{pretrained_model_name_or_path}." - ) - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): - archive_file = pretrained_model_name_or_path - is_local = True - elif is_remote_url(pretrained_model_name_or_path): - filename = pretrained_model_name_or_path - resolved_archive_file = download_url(pretrained_model_name_or_path) - else: - if from_pt: - filename = WEIGHTS_NAME - else: - filename = FLAX_WEIGHTS_NAME - - try: - # Load from URL or cache if already cached - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "token": token, - "user_agent": user_agent, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: - resolved_archive_file = cached_file( - pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs - ) - if resolved_archive_file is not None: - is_sharded = True - - # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. - if resolved_archive_file is None and from_pt: - resolved_archive_file = cached_file( - pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs - ) - if resolved_archive_file is not None: - is_sharded = True - - # If we still haven't found anything, look for `safetensors`. - if resolved_archive_file is None: - # No support for sharded safetensors yet, so we'll raise an error if that's all we find. - filename = SAFE_WEIGHTS_NAME - resolved_archive_file = cached_file( - pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs - ) - - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None: - # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error - # message. - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): - is_sharded = True - raise NotImplementedError( - "Support for sharded checkpoints using safetensors is coming soon!" - ) - elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" - " load this model from those weights." - ) - elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs): - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use" - " `from_pt=True` to load this model from those weights." - ) - else: - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." - ) - except OSError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted - # to the original exception. - raise - except Exception: - # For any other exception, we throw a generic error. - raise OSError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - " from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." - ) - - if is_local: - logger.info(f"loading weights file {archive_file}") - resolved_archive_file = archive_file - filename = resolved_archive_file.split(os.path.sep)[-1] - else: - logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: - resolved_archive_file = None - - # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, _ = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _commit_hash=commit_hash, - ) - - safetensors_from_pt = False - if filename == SAFE_WEIGHTS_NAME: - with safe_open(resolved_archive_file, framework="flax") as f: - safetensors_metadata = f.metadata() - if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: - raise OSError( - f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." - " Make sure you save your model with the `save_pretrained` method." - ) - safetensors_from_pt = safetensors_metadata.get("format") == "pt" - - # init random models - model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) - - if from_pt or safetensors_from_pt: - state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) - else: - if is_sharded: - state = cls.load_flax_sharded_weights(resolved_archive_file) - else: - state = cls.load_flax_weights(resolved_archive_file) - # make sure all arrays are stored as jnp.arrays - # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: - # https://github.com/google/flax/issues/1261 - if _do_init: - state = jax.tree_util.tree_map(jnp.array, state) - else: - # keep the params on CPU if we don't want to initialize - state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) - - if "batch_stats" in state: # if flax model contains batch norm layers - # if model is base model only use model_prefix key - if ( - cls.base_model_prefix not in dict(model.params_shape_tree["params"]) - and cls.base_model_prefix in state["params"] - ): - state["params"] = state["params"][cls.base_model_prefix] - state["batch_stats"] = state["batch_stats"][cls.base_model_prefix] - - # if model is head model and we are loading weights from base model - # we initialize new params dict with base_model_prefix - if ( - cls.base_model_prefix in dict(model.params_shape_tree["params"]) - and cls.base_model_prefix not in state["params"] - ): - state = { - "params": {cls.base_model_prefix: state["params"]}, - "batch_stats": {cls.base_model_prefix: state["batch_stats"]}, - } - - else: - # if model is base model only use model_prefix key - if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: - state = state[cls.base_model_prefix] - - # if model is head model and we are loading weights from base model - # we initialize new params dict with base_model_prefix - if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: - state = {cls.base_model_prefix: state} - - # flatten dicts - state = flatten_dict(state) - - random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree)) - - missing_keys = model.required_params - set(state.keys()) - unexpected_keys = set(state.keys()) - model.required_params - - # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked - for unexpected_key in unexpected_keys.copy(): - if "num_batches_tracked" in unexpected_key[-1]: - unexpected_keys.remove(unexpected_key) - - if missing_keys and not _do_init: - logger.warning( - f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " - "Make sure to call model.init_weights to initialize the missing weights." - ) - cls._missing_keys = missing_keys - - # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys = [] - for key in state: - if key in random_state and state[key].shape != random_state[key].shape: - if ignore_mismatched_sizes: - mismatched_keys.append((key, state[key].shape, random_state[key].shape)) - state[key] = random_state[key] - else: - raise ValueError( - f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " - f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " - "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " - "model." - ) - - # add missing keys as random parameters if we are initializing - if missing_keys and _do_init: - for missing_key in missing_keys: - state[missing_key] = random_state[missing_key] - - # remove unexpected keys to not be saved again - for unexpected_key in unexpected_keys: - del state[unexpected_key] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" - " with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" - " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" - f" was trained on, you can already use {model.__class__.__name__} for predictions without further" - " training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." - ) - - # dictionary of key: dtypes for the model params - param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state) - # extract keys of parameters not in jnp.float32 - fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] - bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] - - # raise a warning if any of the parameters are not in jnp.float32 - if len(fp16_params) > 0: - logger.warning( - f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " - f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" - "You should probably UPCAST the model weights to float32 if this was not intended. " - "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." - ) - - if len(bf16_params) > 0: - logger.warning( - f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " - f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" - "You should probably UPCAST the model weights to float32 if this was not intended. " - "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." - ) - - # If it is a model with generation capabilities, attempt to load the generation config - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - _from_auto=from_auto_class, - _from_pipeline=from_pipeline, - **kwargs, - ) - except OSError: - logger.info( - "Generation config file not found, using a generation config created from the model config." - ) - pass - - if _do_init: - # set correct parameters - model.params = unflatten_dict(state) - return model - else: - return model, unflatten_dict(state) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - params=None, - push_to_hub=False, - max_shard_size="10GB", - token: Optional[Union[str, bool]] = None, - safe_serialization: bool = False, - **kwargs, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~FlaxPreTrainedModel.from_pretrained`]` class method - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): - The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size - lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - - - - If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard - which will be bigger than `max_shard_size`. - - - - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - kwargs (`dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or through msgpack. - """ - use_auth_token = kwargs.pop("use_auth_token", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if token is not None: - kwargs["token"] = token - - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = self._create_repo(repo_id, **kwargs) - files_timestamps = self._get_files_timestamps(save_directory) - - # get abs dir - save_directory = os.path.abspath(save_directory) - # save config as well - self.config.architectures = [self.__class__.__name__[4:]] - - # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be - # loaded from the Hub. - if self._auto_class is not None: - custom_object_save(self, save_directory, config=self.config) - - self.config.save_pretrained(save_directory) - if self.can_generate(): - self.generation_config.save_pretrained(save_directory) - - # save model - weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME - output_model_file = os.path.join(save_directory, weights_name) - - shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size) - # Clean the folder from a previous save - for filename in os.listdir(save_directory): - full_filename = os.path.join(save_directory, filename) - weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") - if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards: - os.remove(full_filename) - - if index is None: - if safe_serialization: - params = params if params is not None else self.params - flat_dict = flatten_dict(params, sep=".") - safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"}) - else: - with open(output_model_file, "wb") as f: - params = params if params is not None else self.params - model_bytes = to_bytes(params) - f.write(model_bytes) - - else: - save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME) - # Save the index as well - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - logger.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - for shard_file, shard in shards.items(): - # the shard item are unflattened, to save them we need to flatten them again - with open(os.path.join(save_directory, shard_file), mode="wb") as f: - params = unflatten_dict(shard, sep="/") - shard_bytes = to_bytes(params) - f.write(shard_bytes) - - logger.info(f"Model weights saved in {output_model_file}") - - if push_to_hub: - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=token, - ) - - @classmethod - def register_for_auto_class(cls, auto_class="FlaxAutoModel"): - """ - Register this class with a given auto class. This should only be used for custom models as the ones in the - library are already mapped with an auto class. - - - - Args: - auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`): - The auto class to register this new model with. - """ - if not isinstance(auto_class, str): - auto_class = auto_class.__name__ - - import transformers.models.auto as auto_module - - if not hasattr(auto_module, auto_class): - raise ValueError(f"{auto_class} is not a valid auto class.") - - cls._auto_class = auto_class - - -# To update the docstring, we need to copy the method, otherwise we change the original docstring. -FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub) -if FlaxPreTrainedModel.push_to_hub.__doc__ is not None: - FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format( - object="model", object_class="FlaxAutoModel", object_files="model checkpoint" - ) - - -def overwrite_call_docstring(model_class, docstring): - # copy __call__ function to be sure docstring is changed only for this function - model_class.__call__ = copy_func(model_class.__call__) - # delete existing docstring - model_class.__call__.__doc__ = None - # set correct docstring - model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) - - -def append_call_sample_docstring( - model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None -): - model_class.__call__ = copy_func(model_class.__call__) - model_class.__call__ = add_code_sample_docstrings( - checkpoint=checkpoint, - output_type=output_type, - config_class=config_class, - model_cls=model_class.__name__, - revision=revision, - real_checkpoint=real_checkpoint, - )(model_class.__call__) - - -def append_replace_return_docstrings(model_class, output_type, config_class): - model_class.__call__ = copy_func(model_class.__call__) - model_class.__call__ = replace_return_docstrings( - output_type=output_type, - config_class=config_class, - )(model_class.__call__) diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py deleted file mode 100644 index c7491b67f9ae..000000000000 --- a/src/transformers/modeling_tf_outputs.py +++ /dev/null @@ -1,990 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import tensorflow as tf - -from .utils import ModelOutput - - -@dataclass -class TFBaseModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBaseModelOutputWithNoAttention(ModelOutput): - """ - Base class for model's outputs, with potential hidden states. - - Args: - last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFBaseModelOutputWithPooling(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) further processed by a - Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence - prediction (classification) objective during pretraining. - - This output is usually *not* a good summary of the semantic content of the input, you're often better with - averaging or pooling the sequence of hidden-states for the whole input sequence. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a pooling operation on the spatial dimensions. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) further processed by a - Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence - prediction (classification) objective during pretraining. - - This output is usually *not* a good summary of the semantic content of the input, you're often better with - averaging or pooling the sequence of hidden-states for the whole input sequence. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBaseModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBaseModelOutputWithCrossAttentions(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSeq2SeqModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - last_hidden_state: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor] | None = None - decoder_attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor] | None = None - encoder_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFCausalLMOutput(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFCausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFCausalLMOutputWithCrossAttentions(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFMaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSeq2SeqLMOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): - Language modeling loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor] | None = None - decoder_attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor] | None = None - encoder_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFNextSentencePredictorOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided): - Next sentence prediction loss. - logits (`tf.Tensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSeq2SeqSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence sentence classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)` - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor] | None = None - decoder_attentions: tuple[tf.Tensor] | None = None - cross_attentions: tuple[tf.Tensor] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor] | None = None - encoder_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSemanticSegmenterOutput(ModelOutput): - """ - Base class for outputs of semantic segmentation models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. - - - - The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is - to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the - original image size as post-processing. You should always check your logits shape and resize as needed. - - - - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSemanticSegmenterOutputWithNoAttention(ModelOutput): - """ - Base class for outputs of semantic segmentation models that do not output attention scores. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. - - - - The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is - to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the - original image size as post-processing. You should always check your logits shape and resize as needed. - - - - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - - -@dataclass -class TFImageClassifierOutput(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called - feature maps) of the model at the output of each stage. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFMultipleChoiceModelOutput(ModelOutput): - """ - Base class for outputs of multiple choice models. - - Args: - loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided): - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFTokenClassifierOutput(ModelOutput): - """ - Base class for outputs of token classification models. - - Args: - loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided) : - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering models. - - Args: - loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - start_logits: tf.Tensor | None = None - end_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence question answering models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: tf.Tensor | None = None - start_logits: tf.Tensor | None = None - end_logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor] | None = None - decoder_attentions: tuple[tf.Tensor] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor] | None = None - encoder_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFSequenceClassifierOutputWithPast(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFImageClassifierOutputWithNoAttention(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called - feature maps) of the model at the output of each stage. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFMaskedImageModelingOutput(ModelOutput): - """ - Base class for outputs of masked image completion / in-painting models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): - Reconstruction loss. - reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed / completed images. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when - `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called - feature maps) of the model at the output of each stage. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when - `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - reconstruction: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - @property - def logits(self): - warnings.warn( - "logits attribute is deprecated and will be removed in version 5 of Transformers." - " Please use the reconstruction attribute to retrieve the final output instead.", - FutureWarning, - ) - return self.reconstruction diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py deleted file mode 100644 index 8f688af7be36..000000000000 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ /dev/null @@ -1,676 +0,0 @@ -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch - TF 2.0 general utilities.""" - -import os -import re - -import numpy - -from .utils import ( - ExplicitEnum, - check_torch_load_is_safe, - expand_dims, - is_numpy_array, - is_safetensors_available, - is_torch_tensor, - logging, - reshape, - squeeze, - tensor_size, -) -from .utils import transpose as transpose_func - - -if is_safetensors_available(): - from safetensors import safe_open - - -logger = logging.get_logger(__name__) - - -class TransposeType(ExplicitEnum): - """ - Possible ... - """ - - NO = "no" - SIMPLE = "simple" - CONV1D = "conv1d" - CONV2D = "conv2d" - - -def convert_tf_weight_name_to_pt_weight_name( - tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None -): - """ - Convert a TF 2.0 model variable name in a pytorch model weight name. - - Conventions for TF2.0 scopes -> PyTorch attribute names conversions: - - - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) - - return tuple with: - - - pytorch model weight name - - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be - transposed with regards to each other - """ - if name_scope is not None: - if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name: - raise ValueError( - f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error " - "in Transformers, so (unless you were doing something really evil) please open an issue to report it!" - ) - tf_name = tf_name[len(name_scope) :] - tf_name = tf_name.lstrip("/") - tf_name = tf_name.replace(":0", "") # device ids - if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10: - # ReDOS check - raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name) - tf_name = re.sub( - r"/[^/]*___([^/]*)/", r"/\1/", tf_name - ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - tf_name = tf_name.replace( - "_._", "/" - ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) - tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end - tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators - # Some weights have a single name without "/" such as final_logits_bias in BART - if len(tf_name) > 1: - tf_name = tf_name[1:] # Remove level zero - - tf_weight_shape = list(tf_weight_shape) - - # When should we transpose the weights - if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4: - transpose = TransposeType.CONV2D - elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3: - transpose = TransposeType.CONV1D - elif bool( - tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"] - or "emb_projs" in tf_name - or "out_projs" in tf_name - ): - transpose = TransposeType.SIMPLE - else: - transpose = TransposeType.NO - - # Convert standard TF2.0 names in PyTorch names - if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": - tf_name[-1] = "weight" - if tf_name[-1] == "beta": - tf_name[-1] = "bias" - - # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here - if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel": - tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") - - # Remove prefix if needed - tf_name = ".".join(tf_name) - if start_prefix_to_remove: - tf_name = tf_name.replace(start_prefix_to_remove, "", 1) - - return tf_name, transpose - - -def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True): - """ - Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a - framework agnostic way. - """ - if transpose is TransposeType.CONV2D: - # Conv2D weight: - # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1]) - # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel) - axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1) - weight = transpose_func(weight, axes=axes) - elif transpose is TransposeType.CONV1D: - # Conv1D weight: - # PT: (num_out_channel, num_in_channel, kernel) - # -> TF: (kernel, num_in_channel, num_out_channel) - weight = transpose_func(weight, axes=(2, 1, 0)) - elif transpose is TransposeType.SIMPLE: - weight = transpose_func(weight) - - if match_shape is None: - return weight - - if len(match_shape) < len(weight.shape): - weight = squeeze(weight) - elif len(match_shape) > len(weight.shape): - weight = expand_dims(weight, axis=0) - - if list(match_shape) != list(weight.shape): - try: - weight = reshape(weight, match_shape) - except AssertionError as e: - e.args += (match_shape, match_shape) - raise e - - return weight - - -##################### -# PyTorch => TF 2.0 # -##################### - - -def load_pytorch_checkpoint_in_tf2_model( - tf_model, - pytorch_checkpoint_path, - tf_inputs=None, - allow_missing_keys=False, - output_loading_info=False, - _prefix=None, - tf_to_pt_weight_rename=None, -): - """Load pytorch checkpoints in a TF 2.0 model""" - try: - import tensorflow as tf # noqa: F401 - import torch # noqa: F401 - from safetensors.torch import load_file as safe_load_file # noqa: F401 - except ImportError: - logger.error( - "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - # Treats a single file as a collection of shards with 1 shard. - if isinstance(pytorch_checkpoint_path, str): - pytorch_checkpoint_path = [pytorch_checkpoint_path] - - # Loads all shards into a single state dictionary - pt_state_dict = {} - for path in pytorch_checkpoint_path: - pt_path = os.path.abspath(path) - logger.info(f"Loading PyTorch weights from {pt_path}") - if pt_path.endswith(".safetensors"): - state_dict = safe_load_file(pt_path) - else: - check_torch_load_is_safe() - state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) - - pt_state_dict.update(state_dict) - - logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters") - - return load_pytorch_weights_in_tf2_model( - tf_model, - pt_state_dict, - tf_inputs=tf_inputs, - allow_missing_keys=allow_missing_keys, - output_loading_info=output_loading_info, - _prefix=_prefix, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ) - - -def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False): - """Load pytorch checkpoints in a TF 2.0 model""" - pt_state_dict = pt_model.state_dict() - - return load_pytorch_weights_in_tf2_model( - tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys - ) - - -def load_pytorch_weights_in_tf2_model( - tf_model, - pt_state_dict, - tf_inputs=None, - allow_missing_keys=False, - output_loading_info=False, - _prefix=None, - tf_to_pt_weight_rename=None, -): - """Load pytorch state_dict in a TF 2.0 model.""" - try: - import tensorflow as tf # noqa: F401 - import torch # noqa: F401 - except ImportError: - logger.error( - "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision - pt_state_dict = { - k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() - } - return load_pytorch_state_dict_in_tf2_model( - tf_model, - pt_state_dict, - tf_inputs=tf_inputs, - allow_missing_keys=allow_missing_keys, - output_loading_info=output_loading_info, - _prefix=_prefix, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ) - - -def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name): - if len(unexpected_keys) > 0: - logger.warning( - "Some weights of the PyTorch model were not used when initializing the TF 2.0 model" - f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing" - f" {class_name} from a PyTorch model trained on another task or with another architecture" - " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS" - f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect" - " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the" - f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a" - " down-stream task to be able to use it for predictions and inference." - ) - else: - logger.warning( - f"All the weights of {class_name} were initialized from the PyTorch model.\n" - "If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {class_name} for predictions without further training." - ) - - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {class_name} were not initialized from the model checkpoint" - f" are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." - ) - - -def load_pytorch_state_dict_in_tf2_model( - tf_model, - pt_state_dict, - tf_inputs=None, - allow_missing_keys=False, - output_loading_info=False, - _prefix=None, - tf_to_pt_weight_rename=None, - ignore_mismatched_sizes=False, - skip_logger_warnings=False, -): - """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading - safetensors archive created with the safe_open() function.""" - import tensorflow as tf - - if tf_inputs is None: - tf_inputs = tf_model.dummy_inputs - - if _prefix is None: - _prefix = "" - if tf_inputs: - with tf.name_scope(_prefix): - tf_model(tf_inputs, training=False) # Make sure model is built - # Convert old format to new format if needed from a PyTorch state_dict - tf_keys_to_pt_keys = {} - for key in pt_state_dict: - new_key = None - if "gamma" in key: - new_key = key.replace("gamma", "weight") - if "beta" in key: - new_key = key.replace("beta", "bias") - if "running_var" in key: - new_key = key.replace("running_var", "moving_variance") - if "running_mean" in key: - new_key = key.replace("running_mean", "moving_mean") - - # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 - key_components = key.split(".") - name = None - if key_components[-3::2] == ["parametrizations", "original0"]: - name = key_components[-2] + "_g" - elif key_components[-3::2] == ["parametrizations", "original1"]: - name = key_components[-2] + "_v" - if name is not None: - key_components = key_components[:-3] + [name] - new_key = ".".join(key_components) - - if new_key is None: - new_key = key - tf_keys_to_pt_keys[new_key] = key - - # Matt: All TF models store the actual model stem in a MainLayer class, including the base model. - # In PT, the derived models (with heads) use the base model class as the stem instead, - # and there is no MainLayer class. This means that TF base classes have one - # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that. - start_prefix_to_remove = "" - if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys): - start_prefix_to_remove = tf_model.base_model_prefix + "." - - symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights - tf_loaded_numel = 0 - all_pytorch_weights = set(tf_keys_to_pt_keys.keys()) - missing_keys = [] - mismatched_keys = [] - is_safetensor_archive = hasattr(pt_state_dict, "get_tensor") - for symbolic_weight in symbolic_weights: - sw_name = symbolic_weight.name - name, transpose = convert_tf_weight_name_to_pt_weight_name( - sw_name, - start_prefix_to_remove=start_prefix_to_remove, - tf_weight_shape=symbolic_weight.shape, - name_scope=_prefix, - ) - if tf_to_pt_weight_rename is not None: - aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing - for alias in aliases: # The aliases are in priority order, take the first one that matches - if alias in tf_keys_to_pt_keys: - name = alias - break - else: - # If none of the aliases match, just use the first one (it'll be reported as missing) - name = aliases[0] - - # Find associated numpy array in pytorch model state dict - if name not in tf_keys_to_pt_keys: - if allow_missing_keys: - missing_keys.append(name) - continue - elif tf_model._keys_to_ignore_on_load_missing is not None: - # authorized missing keys don't have to be loaded - if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing): - continue - raise AttributeError(f"{name} not found in PyTorch model") - state_dict_name = tf_keys_to_pt_keys[name] - if is_safetensor_archive: - array = pt_state_dict.get_tensor(state_dict_name) - else: - array = pt_state_dict[state_dict_name] - try: - array = apply_transpose(transpose, array, symbolic_weight.shape) - except tf.errors.InvalidArgumentError as e: - if not ignore_mismatched_sizes: - error_msg = str(e) - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise tf.errors.InvalidArgumentError(error_msg) - else: - mismatched_keys.append((name, array.shape, symbolic_weight.shape)) - continue - - tf_loaded_numel += tensor_size(array) - - symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype)) - del array # Immediately free memory to keep peak usage as low as possible - all_pytorch_weights.discard(name) - - logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.") - - unexpected_keys = list(all_pytorch_weights) - - if tf_model._keys_to_ignore_on_load_missing is not None: - for pat in tf_model._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - if tf_model._keys_to_ignore_on_load_unexpected is not None: - for pat in tf_model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - if not skip_logger_warnings: - _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) - - if output_loading_info: - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - } - return tf_model, loading_info - - return tf_model - - -def load_sharded_pytorch_safetensors_in_tf2_model( - tf_model, - safetensors_shards, - tf_inputs=None, - allow_missing_keys=False, - output_loading_info=False, - _prefix=None, - tf_to_pt_weight_rename=None, - ignore_mismatched_sizes=False, -): - all_loading_infos = [] - for shard in safetensors_shards: - with safe_open(shard, framework="tf") as safetensors_archive: - tf_model, loading_info = load_pytorch_state_dict_in_tf2_model( - tf_model, - safetensors_archive, - tf_inputs=tf_inputs, - allow_missing_keys=allow_missing_keys, - output_loading_info=True, - _prefix=_prefix, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ignore_mismatched_sizes=ignore_mismatched_sizes, - skip_logger_warnings=True, # We will emit merged warnings at the end - ) - all_loading_infos.append(loading_info) - # Now we just need to merge the loading info - # Keys are missing only if they're missing in *every* shard - missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos])) - # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard - unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], []) - mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], []) - - _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) - - if output_loading_info: - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - } - return tf_model, loading_info - - return tf_model - - -##################### -# TF 2.0 => PyTorch # -##################### - - -def load_tf2_checkpoint_in_pytorch_model( - pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False -): - """ - Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see - https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). - """ - try: - import tensorflow as tf # noqa: F401 - import torch # noqa: F401 - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - import transformers - - from .modeling_tf_utils import load_tf_weights - - logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}") - - # Instantiate and load the associated TF 2.0 model - tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning - tf_model_class = getattr(transformers, tf_model_class_name) - tf_model = tf_model_class(pt_model.config) - - if tf_inputs is None: - tf_inputs = tf_model.dummy_inputs - - if tf_inputs is not None: - tf_model(tf_inputs, training=False) # Make sure model is built - - load_tf_weights(tf_model, tf_checkpoint_path) - - return load_tf2_model_in_pytorch_model( - pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info - ) - - -def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False): - """Load TF 2.0 model in a pytorch model""" - weights = tf_model.weights - - return load_tf2_weights_in_pytorch_model( - pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info - ) - - -def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False): - """Load TF2.0 symbolic weights in a PyTorch model""" - try: - import tensorflow as tf # noqa: F401 - import torch # noqa: F401 - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights} - return load_tf2_state_dict_in_pytorch_model( - pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info - ) - - -def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False): - import torch - - new_pt_params_dict = {} - current_pt_params_dict = dict(pt_model.named_parameters()) - - # Make sure we are able to load PyTorch base models as well as derived models (with heads) - # TF models always have a prefix, some of PyTorch models (base ones) don't - start_prefix_to_remove = "" - if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict): - start_prefix_to_remove = pt_model.base_model_prefix + "." - - # Build a map from potential PyTorch weight names to TF 2.0 Variables - tf_weights_map = {} - for name, tf_weight in tf_state_dict.items(): - pt_name, transpose = convert_tf_weight_name_to_pt_weight_name( - name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape - ) - tf_weights_map[pt_name] = (tf_weight, transpose) - - all_tf_weights = set(tf_weights_map.keys()) - loaded_pt_weights_data_ptr = {} - missing_keys_pt = [] - for pt_weight_name, pt_weight in current_pt_params_dict.items(): - # Handle PyTorch shared weight not duplicated in TF 2.0 - if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0: - new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] - continue - - pt_weight_name_to_check = pt_weight_name - # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 - key_components = pt_weight_name.split(".") - name = None - if key_components[-3::2] == ["parametrizations", "original0"]: - name = key_components[-2] + "_g" - elif key_components[-3::2] == ["parametrizations", "original1"]: - name = key_components[-2] + "_v" - if name is not None: - key_components = key_components[:-3] + [name] - pt_weight_name_to_check = ".".join(key_components) - - # Find associated numpy array in pytorch model state dict - if pt_weight_name_to_check not in tf_weights_map: - if allow_missing_keys: - missing_keys_pt.append(pt_weight_name) - continue - - raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model") - - array, transpose = tf_weights_map[pt_weight_name_to_check] - - array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False) - - if numpy.isscalar(array): - array = numpy.array(array) - if not is_torch_tensor(array) and not is_numpy_array(array): - array = array.numpy() - if is_numpy_array(array): - # Convert to torch tensor - array = torch.from_numpy(array) - - new_pt_params_dict[pt_weight_name] = array - loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array - all_tf_weights.discard(pt_weight_name) - - missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) - missing_keys += missing_keys_pt - - # Some models may have keys that are not in the state by design, removing them before needlessly warning - # the user. - if pt_model._keys_to_ignore_on_load_missing is not None: - for pat in pt_model._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if pt_model._keys_to_ignore_on_load_unexpected is not None: - for pat in pt_model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - "Some weights of the TF 2.0 model were not used when initializing the PyTorch model" - f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" - f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture" - " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS" - f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect" - " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" - " TFBertForSequenceClassification model)." - ) - else: - logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly" - f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" - " use it for predictions and inference." - ) - else: - logger.warning( - f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n" - "If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {pt_model.__class__.__name__} for predictions without further training." - ) - - logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}") - - if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} - return pt_model, loading_info - - return pt_model diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py deleted file mode 100644 index c7bb80656d1b..000000000000 --- a/src/transformers/modeling_tf_utils.py +++ /dev/null @@ -1,3529 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF general model utils.""" - -from __future__ import annotations - -import functools -import gc -import inspect -import json -import os -import pickle -import re -import warnings -from collections.abc import Mapping -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Union - -import h5py -import numpy as np -import tensorflow as tf -from packaging.version import parse - -from . import DataCollatorWithPadding, DefaultDataCollator -from .activations_tf import get_tf_activation -from .configuration_utils import PretrainedConfig -from .dynamic_module_utils import custom_object_save -from .generation import GenerationConfig, TFGenerationMixin -from .tf_utils import ( - convert_batch_encoding, - expand_1d, - load_attributes_from_hdf5_group, - save_attributes_to_hdf5_group, - shape_list, -) -from .utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - TF2_WEIGHTS_INDEX_NAME, - TF2_WEIGHTS_NAME, - TF_WEIGHTS_NAME, - WEIGHTS_INDEX_NAME, - WEIGHTS_NAME, - ModelOutput, - PushToHubMixin, - cached_file, - download_url, - find_labels, - has_file, - is_offline_mode, - is_remote_url, - is_safetensors_available, - is_tf_symbolic_tensor, - logging, - requires_backends, - working_or_temp_dir, -) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files - - -if is_safetensors_available(): - from safetensors import safe_open - from safetensors.tensorflow import save_file as safe_save_file - -if TYPE_CHECKING: - from . import PreTrainedTokenizerBase - -logger = logging.get_logger(__name__) - -if "TF_USE_LEGACY_KERAS" not in os.environ: - os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2 -elif os.environ["TF_USE_LEGACY_KERAS"] != "1": - logger.warning( - "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. " - "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models." - ) - -try: - import tf_keras as keras - from tf_keras import backend as K -except (ModuleNotFoundError, ImportError): - import keras - from keras import backend as K - - if parse(keras.__version__).major > 2: - raise ValueError( - "Your currently installed version of Keras is Keras 3, but this is not yet supported in " - "Transformers. Please install the backwards-compatible tf-keras package with " - "`pip install tf-keras`." - ) - - -tf_logger = tf.get_logger() - -TFModelInputType = Union[ - list[tf.Tensor], - list[np.ndarray], - dict[str, tf.Tensor], - dict[str, np.ndarray], - tf.Tensor, - np.ndarray, -] - - -def dummy_loss(y_true, y_pred): - if y_pred.shape.rank <= 1: - return y_pred - else: - reduction_axes = list(range(1, y_pred.shape.rank)) - return tf.reduce_mean(y_pred, axis=reduction_axes) - - -class TFModelUtilsMixin: - """ - A few utilities for `keras.Model`, to be used as a mixin. - """ - - def num_parameters(self, only_trainable: bool = False) -> int: - """ - Get the number of (optionally, trainable) parameters in the model. - - Args: - only_trainable (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of trainable parameters - - Returns: - `int`: The number of parameters. - """ - if only_trainable: - return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables)) - else: - return self.count_params() - - -def keras_serializable(cls): - """ - Decorate a Keras Layer class to support Keras serialization. - - This is done by: - - 1. Adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at - serialization time. - 2. Wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and - convert it to a config object for the actual layer initializer. - 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not - need to be supplied in `custom_objects` in the call to `keras.models.load_model`. - - Args: - cls (a `keras.layers.Layers subclass`): - Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its - initializer. - - Returns: - The same class object, with modifications for Keras deserialization. - """ - initializer = cls.__init__ - - config_class = getattr(cls, "config_class", None) - if config_class is None: - raise AttributeError("Must set `config_class` to use @keras_serializable") - - @functools.wraps(initializer) - def wrapped_init(self, *args, **kwargs): - config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None) - - if isinstance(config, dict): - config = config_class.from_dict(config) - initializer(self, config, *args, **kwargs) - elif isinstance(config, PretrainedConfig): - if len(args) > 0: - initializer(self, *args, **kwargs) - else: - initializer(self, config, *args, **kwargs) - else: - raise TypeError("Must pass either `config` (PretrainedConfig) or `config` (dict)") - - self._config = config - self._kwargs = kwargs - - cls.__init__ = wrapped_init - - if not hasattr(cls, "get_config"): - raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses") - if hasattr(cls.get_config, "_is_default"): - - def get_config(self): - cfg = super(cls, self).get_config() - cfg["config"] = self._config.to_dict() - cfg.update(self._kwargs) - return cfg - - cls.get_config = get_config - - cls._keras_serializable = True - if hasattr(keras.utils, "register_keras_serializable"): - cls = keras.utils.register_keras_serializable()(cls) - return cls - - -class TFCausalLanguageModelingLoss: - """ - Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token. - - - - Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. - - - """ - - def hf_compute_loss(self, labels, logits): - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - if self.config.tf_legacy_loss: - # make sure only labels that are not equal to -100 affect the loss - active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) - return loss_fn(labels, reduced_logits) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_loss = loss_fn(tf.nn.relu(labels), logits) - # make sure only labels that are not equal to -100 affect the loss - loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype) - masked_loss = unmasked_loss * loss_mask - reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) - return tf.reshape(reduced_masked_loss, (1,)) - - -class TFQuestionAnsweringLoss: - """ - Loss function suitable for question answering. - """ - - def hf_compute_loss(self, labels, logits): - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - start_loss = loss_fn(labels["start_position"], logits[0]) - end_loss = loss_fn(labels["end_position"], logits[1]) - - return (start_loss + end_loss) / 2.0 - - -class TFTokenClassificationLoss: - """ - Loss function suitable for token classification. - - - - Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. - - - """ - - def hf_compute_loss(self, labels, logits): - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA - if tf.math.reduce_any(labels == -1): - tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") - - if self.config.tf_legacy_loss: - # make sure only labels that are not equal to -100 - # are taken into account as loss - if tf.math.reduce_any(labels == -1): - tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") - active_loss = tf.reshape(labels, (-1,)) != -1 - else: - active_loss = tf.reshape(labels, (-1,)) != -100 - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) - - return loss_fn(labels, reduced_logits) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_loss = loss_fn(tf.nn.relu(labels), logits) - # make sure only labels that are not equal to -100 or -1 - # are taken into account as loss - loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype) - # Avoid possible division by zero later - # Masked positions will have a loss of NaN because -100 and -1 are not valid labels - masked_loss = unmasked_loss * loss_mask - reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) - return tf.reshape(reduced_masked_loss, (1,)) - - -class TFSequenceClassificationLoss: - """ - Loss function suitable for sequence classification. - """ - - def hf_compute_loss(self, labels, logits): - if logits.shape.rank == 1 or logits.shape[1] == 1: - loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE) - if labels.shape.rank == 1: - # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that - labels = tf.expand_dims(labels, axis=-1) - else: - loss_fn = keras.losses.SparseCategoricalCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.NONE - ) - - return loss_fn(labels, logits) - - -class TFMultipleChoiceLoss: - """Loss function suitable for multiple choice tasks.""" - - def hf_compute_loss(self, labels, logits): - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - return loss_fn(labels, logits) - - -class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): - """ - Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens. - - - - Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. - - - """ - - -class TFNextSentencePredictionLoss: - """ - Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence. - - - - Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. - - - """ - - def hf_compute_loss(self, labels, logits): - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - if self.config.tf_legacy_loss: - # make sure only labels that are not equal to -100 - # are taken into account as loss - next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) - next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss) - next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss) - - return loss_fn(next_sentence_label, next_sentence_reduced_logits) - - # make sure only labels that are not equal to -100 - # are taken into account as loss - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits) - ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype) - # Just zero out samples where label is -100, no reduction - masked_ns_loss = unmasked_ns_loss * ns_loss_mask - - return masked_ns_loss - - -def booleans_processing(config, **kwargs): - """ - Process the input booleans of each model. - - Args: - config ([`PretrainedConfig`]): - The config of the running model. - **kwargs: - The boolean parameters - - Returns: - A dictionary with the proper values for each boolean - """ - final_booleans = {} - - # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has - # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`) - if "output_attentions" in kwargs: - final_booleans["output_attentions"] = ( - kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions - ) - final_booleans["output_hidden_states"] = ( - kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states - ) - final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict - - if "use_cache" in kwargs: - final_booleans["use_cache"] = ( - kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None) - ) - return final_booleans - - -def unpack_inputs(func): - """ - Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables - downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input - (common case in Keras). - - Args: - func (`callable`): - The callable function of the TensorFlow model. - - - Returns: - A callable that wraps the original `func` with the behavior described above. - """ - - original_signature = inspect.signature(func) - - @functools.wraps(func) - def run_call_with_unpacked_inputs(self, *args, **kwargs): - # isolates the actual `**kwargs` for the decorated function - kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)} - fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call} - fn_args_and_kwargs.update({"kwargs_call": kwargs_call}) - - # move any arg into kwargs, if they exist - fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) - - # Encoder Decoder models delegate the application of the configuration options to their inner models. - if "EncoderDecoder" in self.__class__.__name__: - config = None - else: - config = self.config - - unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs) - return func(self, **unpacked_inputs) - - # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This - # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below - # Keras would attempt to check the first argument against the literal signature of the wrapper. - run_call_with_unpacked_inputs.__signature__ = original_signature - - return run_call_with_unpacked_inputs - - -def input_processing(func, config, **kwargs): - """ - Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input - has to be named accordingly to the parameters name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32', - name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training. - - Args: - func (`callable`): - The callable function of the TensorFlow model. - config ([`PretrainedConfig`]): - The config of the running model. - **kwargs: - The inputs of the model. - - Returns: - Two lists, one for the missing layers, and another one for the unexpected layers. - """ - signature = dict(inspect.signature(func).parameters) - has_kwargs = bool(signature.pop("kwargs", None)) - signature.pop("self", None) - parameter_names = list(signature.keys()) - main_input_name = parameter_names[0] - main_input = kwargs.pop(main_input_name, None) - output = {} - allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) - - if "inputs" in kwargs["kwargs_call"]: - warnings.warn( - "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.", - FutureWarning, - ) - - output["input_ids"] = kwargs["kwargs_call"].pop("inputs") - - if "decoder_cached_states" in kwargs["kwargs_call"]: - warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" - " `past_key_values` instead.", - FutureWarning, - ) - output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states") - - if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names: - warnings.warn( - "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`" - " instead.", - FutureWarning, - ) - kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past") - elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names: - kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values") - - if has_kwargs: - output["kwargs"] = kwargs.pop("kwargs_call", {}) - else: - if len(kwargs["kwargs_call"]) > 0: - raise ValueError( - "The following keyword arguments are not supported by this model:" - f" {list(kwargs['kwargs_call'].keys())}." - ) - kwargs.pop("kwargs_call") - - for k, v in kwargs.items(): - if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None: - output[k] = v - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - - if isinstance(main_input, (tuple, list)): - for i, input in enumerate(main_input): - # EagerTensors don't allow to use the .name property so we check for a real Tensor - if is_tf_symbolic_tensor(input): - # Tensor names have always the pattern `name:id` then we check only the - # `name` part - tensor_name = input.name.split(":")[0] - - if tensor_name in parameter_names: - output[tensor_name] = input - else: - output[parameter_names[i]] = input - elif isinstance(input, allowed_types) or input is None: - output[parameter_names[i]] = input - else: - raise ValueError( - f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" - f" {parameter_names[i]}." - ) - elif isinstance(main_input, Mapping): - if "inputs" in main_input: - warnings.warn( - "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`" - " instead.", - FutureWarning, - ) - - output["input_ids"] = main_input.pop("inputs") - - if "decoder_cached_states" in main_input: - warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" - " `past_key_values` instead.", - FutureWarning, - ) - output["past_key_values"] = main_input.pop("decoder_cached_states") - - for k, v in dict(main_input).items(): - if isinstance(v, allowed_types) or v is None: - output[k] = v - elif k not in parameter_names and "args" not in parameter_names: - logger.warning( - f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored." - ) - continue - else: - raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") - else: - if tf.is_tensor(main_input) or main_input is None: - output[main_input_name] = main_input - else: - raise ValueError( - f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for" - f" {main_input_name}." - ) - - # Populates any unspecified argument with their default value, according to the signature. - for name in parameter_names: - if name not in list(output.keys()) and name != "args": - output[name] = kwargs.pop(name, signature[name].default) - - # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) - # So to respect the proper output we have to add this exception - if "args" in output: - if output["args"] is not None and is_tf_symbolic_tensor(output["args"]): - tensor_name = output["args"].name.split(":")[0] - output[tensor_name] = output["args"] - else: - # `args` in this case is always the first parameter, then `input_ids` - output["input_ids"] = output["args"] - - del output["args"] - - if "kwargs" in output: - del output["kwargs"] - - cast_output = {} - for key, val in output.items(): - if isinstance(val, tf.Tensor) and val.dtype == tf.int64: - cast_output[key] = tf.cast(val, tf.int32) - elif isinstance(val, np.ndarray) and val.dtype == np.int64: - cast_output[key] = val.astype(np.int32) - else: - cast_output[key] = val - - output = cast_output - del cast_output - - if config is not None: - boolean_dict = { - k: v - for k, v in output.items() - if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] - } - - output.update( - booleans_processing( - config=config, - **boolean_dict, - ) - ) - - return output - - -def strip_model_name_and_prefix(name, _prefix=None): - if _prefix is not None and name.startswith(_prefix): - name = name[len(_prefix) :] - if name.startswith("/"): - name = name[1:] - if "model." not in name and len(name.split("/")) > 1: - name = "/".join(name.split("/")[1:]) - return name - - -def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME): - """ - Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a - given size. - - The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no - optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the - limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], - [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. - - - - If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will - have a size greater than `max_shard_size`. - - - - Args: - weights (`dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save. - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): - The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit - (like `"5MB"`). - """ - max_shard_size = convert_file_size_to_int(max_shard_size) - - sharded_state_dicts = [] - current_block = [] - current_block_size = 0 - total_size = 0 - - for item in weights: - weight_size = item.numpy().size * item.dtype.size - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) - current_block = [] - current_block_size = 0 - - current_block.append(item) - current_block_size += weight_size - total_size += weight_size - - # Add the last block - sharded_state_dicts.append(current_block) - - # If we only have one shard, we return it - if len(sharded_state_dicts) == 1: - return {weights_name: sharded_state_dicts[0]}, None - - # Otherwise, let's build the index - weight_map = {} - shards = {} - for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.h5") - shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" - ) - shards[shard_file] = shard - for weight in shard: - weight_name = weight.name - weight_map[weight_name] = shard_file - - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - return shards, index - - -def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None): - """ - This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load - the TF weights from the shard file accordingly to their names and shapes. - - This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being - loaded in the model. - - Args: - model (`keras.models.Model`): The model in which to load the checkpoint. - shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. - ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): - Whether or not to ignore the mismatch between the sizes - strict (`bool`, *optional*, defaults to `True`): - Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. - - Returns: - Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the - mismatched layers. - """ - - # Load the index - unexpected_keys = set() - saved_keys = set() - mismatched_keys = set() - - # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load - # the weight, we have to get rid of the first prefix of the name of the layer. - model_keys = set() - model_layer_map = {} - for i, k in enumerate(model.weights): - layer_name = k.name - if _prefix is not None and layer_name.startswith(_prefix): - layer_name = layer_name[len(_prefix) :] - layer_name = layer_name.lstrip("/") - if not ("model." in layer_name or len(layer_name.split("/")) == 1): - layer_name = "/".join(layer_name.split("/")[1:]) - model_keys.add(layer_name) - model_layer_map[layer_name] = i - - for shard_file in shard_files: - saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard( - model, - model_layer_map, - shard_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=_prefix, - ) - saved_keys.update(saved_weight_names_set) - unexpected_keys.update(unexpected_keys_set) - mismatched_keys.update(mismatched_keys_set) - gc.collect() - - missing_keys = model_keys - saved_keys - if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): - error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" - if len(missing_keys) > 0: - str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) - error_message += f"\nMissing key(s): {str_missing_keys}." - if len(unexpected_keys) > 0: - str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) - error_message += f"\nMissing key(s): {str_unexpected_keys}." - raise RuntimeError(error_message) - - return missing_keys, unexpected_keys, mismatched_keys - - -def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): - """ - Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors. - Handles missing keys and unexpected keys. - - Args: - model (`keras.models.Model`): Model in which the weights are loaded - model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model. - resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys - - Returns: - `keras.models.Model`: Three lists, one for the layers that were found and successfully restored (from the - shard file), one for the mismatched layers, and another one for the unexpected layers. - """ - saved_weight_names_set = set() - saved_weights = {} - mismatched_keys = set() - unexpected_keys = set() - # Read the H5 file - try: - with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: - # Retrieve the name of each layer from the H5 file - saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) - weight_value_tuples = [] - - # Compute missing and unexpected sub layers - # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] - for layer_name in saved_h5_model_layers_name: - h5_layer_object = sharded_checkpoint_file[layer_name] - saved_weights[layer_name] = np.asarray(h5_layer_object) - - saved_weight_names_set.add(layer_name) - - if layer_name not in model_layer_map: - unexpected_keys.add(layer_name) - else: - symbolic_weight = model.weights[model_layer_map[layer_name]] - - saved_weight_value = saved_weights[layer_name] - # If the current weight is found - if saved_weight_value is not None: - # Check if the shape of the current weight and the one from the H5 file are different - if K.int_shape(symbolic_weight) != saved_weight_value.shape: - # If yes we reshape the weight from the H5 file accordingly to the current weight - # If the two shapes are not compatible we raise an issue - try: - array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) - except ValueError as e: - if ignore_mismatched_sizes: - mismatched_keys.add( - (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) - ) - continue - else: - raise e - else: - array = saved_weight_value - - # We create the tuple that will be loaded and add it to the final list - weight_value_tuples.append((symbolic_weight, array)) - - K.batch_set_value(weight_value_tuples) - - return saved_weight_names_set, unexpected_keys, mismatched_keys - - except Exception as e: - try: - with open(resolved_archive_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained" - " model. Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' " - f"at '{resolved_archive_file}'. " - "If you tried to load a TF model from a sharded checkpoint, you should try converting the model " - "by loading it in pytorch and saving it locally. A conversion script should be released soon." - ) - - -def load_tf_sharded_weights_from_safetensors( - model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None -): - """ - This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint. - Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and - shapes. - - This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being - loaded in the model. - - Args: - model (`keras.models.Model`): The model in which to load the checkpoint. - shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. - ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): - Whether or not to ignore the mismatch between the sizes - strict (`bool`, *optional*, defaults to `True`): - Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. - - Returns: - Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the - mismatched layers. - """ - - # Load the index - unexpected_keys = set() - all_missing_keys = [] - mismatched_keys = set() - - for shard_file in shard_files: - missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors( - model, - shard_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=_prefix, - ) - all_missing_keys.append(set(missing_layers)) - unexpected_keys.update(unexpected_layers) - mismatched_keys.update(mismatched_layers) - gc.collect() - missing_keys = set.intersection(*all_missing_keys) - - if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): - error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" - if len(missing_keys) > 0: - str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) - error_message += f"\nMissing key(s): {str_missing_keys}." - if len(unexpected_keys) > 0: - str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) - error_message += f"\nMissing key(s): {str_unexpected_keys}." - raise RuntimeError(error_message) - - return missing_keys, unexpected_keys, mismatched_keys - - -def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): - """ - Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and - shapes. - - Args: - model (`keras.models.Model`): - The model to load the weights into. - resolved_archive_file (`str`): - The location of the H5 file. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to ignore weights with shapes that don't match between the checkpoint of the model. - - Returns: - Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the - mismatched layers. - """ - if resolved_archive_file.endswith(".safetensors"): - load_function = load_tf_weights_from_safetensors - else: - load_function = load_tf_weights_from_h5 - - return load_function( - model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix - ) - - -def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): - mismatched_layers = [] - - # Read the H5 file - with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: - # Retrieve the name of each layer from the H5 file - saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) - - # Find the missing layers from the high level list of layers - missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name) - - # Find the unexpected layers from the high level list of layers - unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers}) - saved_weight_names_set = set() - symbolic_weights_names = set() - weight_value_tuples = [] - - # Compute missing and unexpected sub layers - # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] - for layer in model.layers: - # if layer_name from the H5 file belongs to the layers from the instantiated model - if layer.name in saved_h5_model_layers_name: - # Get the H5 layer object from its name - h5_layer_object = sharded_checkpoint_file[layer.name] - # Get all the weights as a list from the layer object - symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - saved_weights = {} - - # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} - # And a set with only the names - for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): - # TF names always start with the model name so we ignore it - name = "/".join(weight_name.split("/")[1:]) - - if _prefix is not None: - name = _prefix + "/" + name - - saved_weights[name] = np.asarray(h5_layer_object[weight_name]) - - # Add the updated name to the final list for computing missing/unexpected values - saved_weight_names_set.add(name) - - # Loop over each weights from the instantiated model and compare with the weights from the H5 file - for symbolic_weight in symbolic_weights: - # TF names always start with the model name so we ignore it - if _prefix is not None: - delimiter = len(_prefix.split("/")) - symbolic_weight_name = "/".join( - symbolic_weight.name.split("/")[:delimiter] - + symbolic_weight.name.split("/")[delimiter + 1 :] - ) - else: - symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) - - # here we check if the current weight is among the weights from the H5 file - # If yes, get the weight_value of the corresponding weight from the H5 file - # If not, make the value to None - saved_weight_value = saved_weights.get(symbolic_weight_name) - - # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's - # `model.shared/embeddings:0` are stored as `model.shared/weights:0`) - if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"): - symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0" - saved_weight_value = saved_weights.get(symbolic_weight_name) - - # Add the updated name to the final list for computing missing/unexpected values - symbolic_weights_names.add(symbolic_weight_name) - - # If the current weight is found - if saved_weight_value is not None: - # Check if the shape of the current weight and the one from the H5 file are different - if K.int_shape(symbolic_weight) != saved_weight_value.shape: - # If yes we reshape the weight from the H5 file accordingly to the current weight - # If the two shapes are not compatible we raise an issue - try: - array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) - except ValueError as e: - if ignore_mismatched_sizes: - mismatched_layers.append( - (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) - ) - continue - else: - raise e - else: - array = saved_weight_value - - # We create the tuple that will be loaded and add it to the final list - weight_value_tuples.append((symbolic_weight, array)) - - # Load all the weights - K.batch_set_value(weight_value_tuples) - - # Compute the missing and unexpected layers - missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) - unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) - - return missing_layers, unexpected_layers, mismatched_layers - - -def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): - # Read the safetensors file - with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: - mismatched_layers = [] - weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights] - loaded_weight_names = list(safetensors_archive.keys()) - # Find the missing layers from the high level list of layers - missing_layers = list(set(weight_names) - set(loaded_weight_names)) - # Find the unexpected layers from the high level list of layers - unexpected_layers = list(set(loaded_weight_names) - set(weight_names)) - - for weight in model.weights: - weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix) - if weight_name in loaded_weight_names: - weight_value = safetensors_archive.get_tensor(weight_name) - # Check if the shape of the current weight and the one from the H5 file are different - if K.int_shape(weight) != weight_value.shape: - # If yes we reshape the weight from the H5 file accordingly to the current weight - # If the two shapes are not compatible we raise an issue - try: - weight_value = tf.reshape(weight_value, K.int_shape(weight)) - except (ValueError, tf.errors.InvalidArgumentError) as e: - if ignore_mismatched_sizes: - mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight))) - continue - else: - raise e - - K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor - return missing_layers, unexpected_layers, mismatched_layers - - -def init_copy_embeddings(old_embeddings, new_num_tokens): - r""" - This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case - new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be - kept or not. Example: - - - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4] - - - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1] - - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5] - - - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4] - """ - old_num_tokens, old_embedding_dim = shape_list(old_embeddings) - size_diff = new_num_tokens - old_num_tokens - - # initialize new embeddings - # Copy token embeddings from the previous ones - if tf.math.greater(size_diff, 0): - # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size - # and we create a mask to properly identify the padded values and be replaced by the values of the newly created - # embeddings - current_weights = tf.pad( - old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1 - ) - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True) - mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False) - else: - # if the new size if lower than the old one, we take the current embeddings until the new size - current_weights = tf.slice( - old_embeddings.value(), - tf.convert_to_tensor([0, 0]), - tf.convert_to_tensor([new_num_tokens, old_embedding_dim]), - ) - mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True) - - return mask, current_weights - - -class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin): - r""" - Base class for all TF models. - - [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, - downloading and saving models as well as a few methods common to all models to: - - - resize the input embeddings, - - prune heads in the self-attention heads. - - Class attributes (overridden by derived classes): - - - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class - for this model architecture. - - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived - classes of the same architecture adding modules on top of the base model. - - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP - models, `pixel_values` for vision models and `input_values` for speech models). - """ - - config_class = None - base_model_prefix = "" - main_input_name = "input_ids" - _auto_class = None - _using_dummy_loss = None - _label_to_output_map = None - - # a list of re pattern of tensor names to ignore from the model when loading the model weights - # (and avoid unnecessary warnings). - _keys_to_ignore_on_load_missing = None - # a list of re pattern of tensor names to ignore from the weights when loading the model weights - # (and avoid unnecessary warnings). - _keys_to_ignore_on_load_unexpected = None - _requires_load_weight_prefix = False - - @property - def dummy_inputs(self) -> dict[str, tf.Tensor]: - """ - Dummy inputs to build the network. - - Returns: - `dict[str, tf.Tensor]`: The dummy inputs. - """ - dummies = {} - for key, spec in self.input_signature.items(): - # 2 is the most correct arbitrary size. I will not be taking questions - dummy_shape = [dim if dim is not None else 2 for dim in spec.shape] - if spec.shape[0] is None: - # But let's make the batch size 1 to save memory anyway - dummy_shape[0] = 1 - dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype) - if key == "token_type_ids": - # Some models have token_type_ids but with a vocab_size of 1 - dummies[key] = tf.zeros_like(dummies[key]) - if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters: - if "encoder_hidden_states" not in dummies: - if self.main_input_name == "input_ids": - dummies["encoder_hidden_states"] = tf.ones( - shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states" - ) - else: - raise NotImplementedError( - "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!" - ) - return dummies - - def build_in_name_scope(self): - with tf.name_scope(self.name): - self.build(input_shape=None) - - @property - def framework(self) -> str: - """ - :str: Identifies that this is a TensorFlow model. - """ - return "tf" - - def build(self, input_shape=None): - pass # This is just here to make sure we don't call the superclass build() - - def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - if not isinstance(config, PretrainedConfig): - raise TypeError( - f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " - "`PretrainedConfig`. To create a model from a pretrained model use " - f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" - ) - # Save config and origin of the pretrained weights if given in model - self.config = config - self.name_or_path = config.name_or_path - self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None - self._set_save_spec(self.input_signature) - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) - - def get_config(self): - return self.config.to_dict() - - @functools.wraps(keras.Model.fit) - def fit(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().fit(*args, **kwargs) - - @functools.wraps(keras.Model.train_on_batch) - def train_on_batch(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().train_on_batch(*args, **kwargs) - - @functools.wraps(keras.Model.test_on_batch) - def test_on_batch(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().test_on_batch(*args, **kwargs) - - @functools.wraps(keras.Model.predict_on_batch) - def predict_on_batch(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().predict_on_batch(*args, **kwargs) - - @functools.wraps(keras.Model.predict) - def predict(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().predict(*args, **kwargs) - - @functools.wraps(keras.Model.evaluate) - def evaluate(self, *args, **kwargs): - args, kwargs = convert_batch_encoding(*args, **kwargs) - return super().evaluate(*args, **kwargs) - - @classmethod - def from_config(cls, config, **kwargs): - if isinstance(config, PretrainedConfig): - return cls._from_config(config, **kwargs) - return cls._from_config(cls.config_class.from_dict(config, **kwargs)) - - @classmethod - def _from_config(cls, config, **kwargs): - """ - All context managers that the model should be initialized under go here. - """ - return cls(config, **kwargs) - - def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor: - """ - Prepare the head mask if needed. - - Args: - head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): - The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). - num_hidden_layers (`int`): - The number of hidden layers in the model. - - Returns: - `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with - `[None]` for each layer. - """ - if head_mask is not None: - head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) - else: - head_mask = [None] * num_hidden_layers - - return head_mask - - def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): - """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" - if head_mask.shape.rank == 1: - head_mask = head_mask[None, None, :, None, None] - head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0) - elif head_mask.shape.rank == 2: - head_mask = head_mask[:, None, :, None, None] - assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" - head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility - return head_mask - - @tf.function - def serving(self, inputs): - """ - Args: - Method used for serving the model. Does not have a specific signature, but will be specialized as concrete - functions when saving with `save_pretrained`. - inputs (`dict[str, tf.Tensor]`): - The input of the saved model as a dictionary of tensors. - """ - output = self.call(inputs) - - return self.serving_output(output) - - @property - def input_signature(self) -> dict[str, tf.TensorSpec]: - """ - This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected - shape and dtype for model inputs. It is used for both serving and for generating dummy inputs. - """ - model_inputs = list(inspect.signature(self.call).parameters) - sig = {} - if "input_ids" in model_inputs: - if self.__class__.__name__.endswith("ForMultipleChoice"): - text_dims = 3 - else: - text_dims = 2 - for input_name in ( - "input_ids", - "attention_mask", - "token_type_ids", - "decoder_input_ids", - "decoder_attention_mask", - ): - if input_name in model_inputs: - sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name) - if "pixel_values" in model_inputs: - pixel_values_shape = [None, None, None, None] - if hasattr(self.config, "vision_config"): - vision_config = self.config.vision_config - else: - vision_config = self.config - if hasattr(vision_config, "num_channels"): - pixel_values_shape[1] = vision_config.num_channels - else: - raise NotImplementedError( - "Could not infer number of channels from config, please override input_signature to specify input shapes." - ) - if hasattr(vision_config, "image_size"): - pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size - elif hasattr(vision_config, "input_size"): - pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size - else: - raise NotImplementedError( - "Could not infer input image shape from config, please override input_signature to specify input shapes." - ) - sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values") - if "input_features" in model_inputs: - raise NotImplementedError("Audio models need a manually defined input_signature") - return sig - - def serving_output(self, output): - """ - Prepare the output of the saved model. Can be overridden if specific serving modifications are required. - """ - if not isinstance(output, ModelOutput): - return output - for key in output: - if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False): - output[key] = None - elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False): - output[key] = None - elif key == "past_key_values" and not getattr(self.config, "use_cache", False): - output[key] = None - elif key == "cross_attentions" and not ( - getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False) - ): - output[key] = None - if isinstance(output[key], (tuple, list)): - try: - output[key] = tf.convert_to_tensor(output[key]) - except (ValueError, tf.errors.InvalidArgumentError): - pass # Layers may not have the same dimensions - return output - - @classmethod - def can_generate(cls) -> bool: - """ - Returns whether this model can generate sequences with `.generate()`. - - Returns: - `bool`: Whether this model can generate sequences with `.generate()`. - """ - # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. - # Alternatively, the model can also have a custom `generate` function. - if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): - return False - return True - - def get_input_embeddings(self) -> keras.layers.Layer: - """ - Returns the model's input embeddings layer. - - Returns: - `tf.Variable`: The embeddings layer mapping vocabulary to hidden states. - """ - main_layer = getattr(self, self.base_model_prefix, self) - - if main_layer is not self: - return main_layer.get_input_embeddings() - else: - raise NotImplementedError - - def _save_checkpoint(self, checkpoint_dir, epoch): - if not os.path.isdir(checkpoint_dir): - os.mkdir(checkpoint_dir) - # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer - # state for us, because it requires special handling for objects like custom losses, which we use - # internally and which users are likely to use too - weights_path = os.path.join(checkpoint_dir, "weights.h5") - self.save_weights(weights_path) - extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()} - extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle") - with open(extra_data_path, "wb") as f: - pickle.dump(extra_data, f) - - def prepare_tf_dataset( - self, - dataset: datasets.Dataset, # noqa:F821 - batch_size: int = 8, - shuffle: bool = True, - tokenizer: PreTrainedTokenizerBase | None = None, - collate_fn: Callable | None = None, - collate_fn_args: dict[str, Any] | None = None, - drop_remainder: bool | None = None, - prefetch: bool = True, - ): - """ - Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is - designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without - further modification. The method will drop columns from the dataset if they don't match input names for the - model. If you want to specify the column names to return rather than using the names that match this model, we - recommend using `Dataset.to_tf_dataset()` instead. - - Args: - dataset (`Any`): - A [~`datasets.Dataset`] to be wrapped as a `tf.data.Dataset`. - batch_size (`int`, *optional*, defaults to 8): - The size of batches to return. - shuffle (`bool`, defaults to `True`): - Whether to return samples from the dataset in random order. Usually `True` for training datasets and - `False` for validation/test datasets. - tokenizer ([`PreTrainedTokenizerBase`], *optional*): - A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific - `collate_fn` is passed instead. - collate_fn (`Callable`, *optional*): - A function that collates samples from the dataset into a single batch. Defaults to - `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is - passed. - collate_fn_args (`dict[str, Any]`, *optional*): - A dict of arguments to pass to the `collate_fn` alongside the list of samples. - drop_remainder (`bool`, *optional*): - Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults - to the same setting as `shuffle`. - prefetch (`bool`, defaults to `True`): - Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for - performance, but can be disabled in edge cases. - - - Returns: - `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API. - """ - requires_backends(self, ["datasets"]) - import datasets - - if collate_fn is None: - if tokenizer is None: - collate_fn = DefaultDataCollator(return_tensors="np") - else: - collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np") - if collate_fn_args is None: - collate_fn_args = {} - - if not isinstance(dataset, datasets.Dataset): - raise TypeError("Dataset argument should be a datasets.Dataset!") - model_inputs = list(inspect.signature(self.call).parameters) - model_labels = find_labels(self.__class__) - if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()): - output_signature, _ = dataset._get_output_signature( - dataset, - batch_size=None, - collate_fn=collate_fn, - collate_fn_args=collate_fn_args, - cols_to_retain=model_inputs, - ) - else: - # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain` - # argument. We should remove this once the minimum supported version of datasets is > 2.3.2 - unwanted_columns = [ - feature - for feature in dataset.features - if feature not in model_inputs and feature not in ("label_ids", "label") - ] - dataset = dataset.remove_columns(unwanted_columns) - output_signature, _ = dataset._get_output_signature( - dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args - ) - output_columns = list(output_signature.keys()) - feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels] - label_cols = [col for col in output_columns if col in model_labels] - - # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols` - # were a single element list, the returned element spec would be a single element. Now, passing [feature] - # will return a dict structure {"feature": feature}, and passing a single string will return a single element. - feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols - label_cols = label_cols[0] if len(label_cols) == 1 else label_cols - - if drop_remainder is None: - drop_remainder = shuffle - tf_dataset = dataset.to_tf_dataset( - columns=feature_cols, - label_cols=label_cols, - batch_size=batch_size, - shuffle=shuffle, - drop_remainder=drop_remainder, - collate_fn=collate_fn, - collate_fn_args=collate_fn_args, - prefetch=prefetch, - ) - return tf_dataset - - def compile( - self, - optimizer="rmsprop", - loss="auto_with_warning", - metrics=None, - loss_weights=None, - weighted_metrics=None, - run_eagerly=None, - steps_per_execution=None, - **kwargs, - ): - """ - This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss - function themselves. - """ - if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility - logger.info( - "No loss specified in compile() - the model's internal loss computation will be used as the " - "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! " - "To disable this behaviour please pass a loss argument, or explicitly pass " - "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to " - "get the internal loss without printing this info string." - ) - loss = "auto" - if loss == "auto": - loss = dummy_loss - self._using_dummy_loss = True - else: - self._using_dummy_loss = False - parent_args = list(inspect.signature(keras.Model.compile).parameters.keys()) - # This argument got renamed, we need to support both versions - if "steps_per_execution" in parent_args: - super().compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - loss_weights=loss_weights, - weighted_metrics=weighted_metrics, - run_eagerly=run_eagerly, - steps_per_execution=steps_per_execution, - **kwargs, - ) - else: - super().compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - loss_weights=loss_weights, - weighted_metrics=weighted_metrics, - run_eagerly=run_eagerly, - experimental_steps_per_execution=steps_per_execution, - **kwargs, - ) - - def compute_loss(self, *args, **kwargs): - if hasattr(keras.Model, "compute_loss"): - # This will be true in TF 2.8 or greater - return super().compute_loss(*args, **kwargs) - else: - warnings.warn( - "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss " - "method added in TF 2.8. If you want the original HF compute_loss, please call " - "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, " - "calling compute_loss() will get the Keras method instead.", - FutureWarning, - ) - return self.hf_compute_loss(*args, **kwargs) - - def get_label_to_output_name_mapping(self): - arg_names = list(inspect.signature(self.call).parameters) - if self._label_to_output_map is not None: - return self._label_to_output_map - elif "start_positions" in arg_names: - return {"start_positions": "start_logits", "end_positions": "end_logits"} - elif "sentence_order_label" in arg_names: - return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"} - elif "next_sentence_label" in arg_names: - return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"} - elif "mc_labels" in arg_names: - return {"labels": "logits", "mc_labels": "mc_logits"} - else: - return {} - - def train_step(self, data): - """ - A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models - and supports directly training on the loss output head. In addition, it ensures input keys are copied to the - labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure - that they are available to the model during the forward pass. - """ - - # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` - arg_names = list(inspect.signature(self.call).parameters) - label_kwargs = find_labels(self.__class__) - label_to_output = self.get_label_to_output_name_mapping() - output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): - # Newer TF train steps leave this out - data = expand_1d(data) - x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify - # them during input/label pre-processing. This avoids surprising the user by wrecking their data. - # In addition, modifying mutable Python inputs makes XLA compilation impossible. - if isinstance(x, dict): - x = x.copy() - if isinstance(y, dict): - y = y.copy() - - # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, - # if those keys are not already present in the input dict - if self._using_dummy_loss and y is not None: - # If y is a tensor and the model only has one label-like input, map y to that input - if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): - if isinstance(x, tf.Tensor): - x = {arg_names[0]: x} - label_kwarg = next(iter(label_kwargs)) - if label_kwarg not in x: - x[label_kwarg] = y - # Otherwise, copy keys from y to x as long as they weren't already present in x - elif isinstance(y, dict): - if isinstance(x, tf.Tensor): - x = {arg_names[0]: x} - for key, val in y.items(): - if key in arg_names and key not in x: - x[key] = val - elif output_to_label.get(key) in arg_names and key not in x: - x[output_to_label[key]] = val - if y is None: - y = {key: val for key, val in x.items() if key in label_kwargs} - if not y and not self._using_dummy_loss: - raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") - - if isinstance(y, dict): - # Rename labels at this point to match output heads - y = {label_to_output.get(key, key): val for key, val in y.items()} - - # Run forward pass. - with tf.GradientTape() as tape: - if self._using_dummy_loss and "return_loss" in arg_names: - y_pred = self(x, training=True, return_loss=True) - else: - y_pred = self(x, training=True) - if self._using_dummy_loss: - loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) - else: - loss = None - - # This next block matches outputs to label keys. Tensorflow's standard method for doing this - # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) - if isinstance(y, dict) and len(y) == 1: - if list(y.keys())[0] in y_pred: - y_pred = y_pred[list(y.keys())[0]] - elif list(y_pred.keys())[0] == "loss": - y_pred = y_pred[1] - else: - y_pred = y_pred[0] - _, y = y.popitem() - elif isinstance(y, dict): - # If the labels are a dict, match keys from the output by name - y_pred = {key: val for key, val in y_pred.items() if key in y} - elif isinstance(y, (tuple, list)): - # If the labels are a tuple/list, match keys to the output by order, skipping the loss. - if list(y_pred.keys())[0] == "loss": - y_pred = y_pred.to_tuple()[1:] - else: - y_pred = y_pred.to_tuple() - y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems - else: - # If the labels are a single tensor, match them to the first non-loss tensor in the output - if list(y_pred.keys())[0] == "loss": - y_pred = y_pred[1] - else: - y_pred = y_pred[0] - - if loss is None: - loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) - - # Run backwards pass. - self.optimizer.minimize(loss, self.trainable_variables, tape=tape) - - self.compiled_metrics.update_state(y, y_pred, sample_weight) - # Collect metrics to return - return_metrics = {} - for metric in self.metrics: - result = metric.result() - if isinstance(result, dict): - return_metrics.update(result) - else: - return_metrics[metric.name] = result - return return_metrics - - def test_step(self, data): - """ - A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models - and supports directly training on the loss output head. In addition, it ensures input keys are copied to the - labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure - that they are available to the model during the forward pass. - """ - # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` - arg_names = list(inspect.signature(self.call).parameters) - label_kwargs = find_labels(self.__class__) - label_to_output = self.get_label_to_output_name_mapping() - output_to_label = {val: key for key, val in label_to_output.items()} - if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): - # Newer versions leave this out - data = expand_1d(data) - x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify - # them during input/label pre-processing. This avoids surprising the user by wrecking their data. - # In addition, modifying mutable Python inputs makes XLA compilation impossible. - if isinstance(x, dict): - x = x.copy() - if isinstance(y, dict): - y = y.copy() - - # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, - # if those keys are not already present in the input dict - if self._using_dummy_loss and y is not None: - arg_names = list(inspect.signature(self.call).parameters) - # If y is a tensor and the model only has one label-like input, map y to that input - if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): - if isinstance(x, tf.Tensor): - x = {arg_names[0]: x} - label_kwarg = next(iter(label_kwargs)) - if label_kwarg not in x: - x[label_kwarg] = y - # Otherwise, copy keys from y to x as long as they weren't already present in x - elif isinstance(y, dict): - if isinstance(x, tf.Tensor): - x = {arg_names[0]: x} - for key, val in y.items(): - if key in arg_names and key not in x: - x[key] = val - elif output_to_label.get(key) in arg_names and key not in x: - x[output_to_label[key]] = val - if y is None: - y = {key: val for key, val in x.items() if key in label_kwargs} - if not y and not self._using_dummy_loss: - raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") - - if isinstance(y, dict): - # Rename labels at this point to match output heads - y = {label_to_output.get(key, key): val for key, val in y.items()} - - # Run forward pass. - if self._using_dummy_loss and "return_loss" in arg_names: - y_pred = self(x, return_loss=True, training=False) - else: - y_pred = self(x, training=False) - if self._using_dummy_loss: - loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) - else: - loss = None - - # This next block matches outputs to label keys. Tensorflow's standard method for doing this - # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) - if isinstance(y, dict) and len(y) == 1: - if list(y.keys())[0] in y_pred: - y_pred = y_pred[list(y.keys())[0]] - elif list(y_pred.keys())[0] == "loss": - y_pred = y_pred[1] - else: - y_pred = y_pred[0] - _, y = y.popitem() - elif isinstance(y, dict): - # If the labels are a dict, match keys from the output by name - y_pred = {key: val for key, val in y_pred.items() if key in y} - elif isinstance(y, (tuple, list)): - # If the labels are a tuple/list, match keys to the output by order, skipping the loss. - if list(y_pred.keys())[0] == "loss": - y_pred = y_pred.to_tuple()[1:] - else: - y_pred = y_pred.to_tuple() - y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems - else: - # If the labels are a single tensor, match them to the first non-loss tensor in the output - if list(y_pred.keys())[0] == "loss": - y_pred = y_pred[1] - else: - y_pred = y_pred[0] - - if loss is None: - loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) - - self.compiled_metrics.update_state(y, y_pred, sample_weight) - # Collect metrics to return - return_metrics = {} - for metric in self.metrics: - result = metric.result() - if isinstance(result, dict): - return_metrics.update(result) - else: - return_metrics[metric.name] = result - return return_metrics - - def create_model_card( - self, - output_dir, - model_name: str, - language: str | None = None, - license: str | None = None, - tags: str | None = None, - finetuned_from: str | None = None, - tasks: str | None = None, - dataset_tags: str | list[str] | None = None, - dataset: str | list[str] | None = None, - dataset_args: str | list[str] | None = None, - ): - """ - Creates a draft of a model card using the information available to the `Trainer`. - - Args: - output_dir (`str` or `os.PathLike`): - The folder in which to create the model card. - model_name (`str`, *optional*): - The name of the model. - language (`str`, *optional*): - The language of the model (if applicable) - license (`str`, *optional*): - The license of the model. Will default to the license of the pretrained model used, if the original - model given to the `Trainer` comes from a repo on the Hub. - tags (`str` or `list[str]`, *optional*): - Some tags to be included in the metadata of the model card. - finetuned_from (`str`, *optional*): - The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo - of the original model given to the `Trainer` (if it comes from the Hub). - tasks (`str` or `list[str]`, *optional*): - One or several task identifiers, to be included in the metadata of the model card. - dataset_tags (`str` or `list[str]`, *optional*): - One or several dataset tags, to be included in the metadata of the model card. - dataset (`str` or `list[str]`, *optional*): - One or several dataset identifiers, to be included in the metadata of the model card. - dataset_args (`str` or `list[str]`, *optional*): - One or several dataset arguments, to be included in the metadata of the model card. - """ - # Avoids a circular import by doing this when necessary. - from .modelcard import TrainingSummary # tests_ignore - - training_summary = TrainingSummary.from_keras( - self, - keras_history=self.history, - language=language, - license=license, - tags=tags, - model_name=model_name, - finetuned_from=finetuned_from, - tasks=tasks, - dataset_tags=dataset_tags, - dataset=dataset, - dataset_args=dataset_args, - ) - model_card = training_summary.to_model_card() - with open(os.path.join(output_dir, "README.md"), "w") as f: - f.write(model_card) - - def set_input_embeddings(self, value): - """ - Set model's input embeddings - - Args: - value (`tf.Variable`): - The new weights mapping hidden states to vocabulary. - """ - main_layer = getattr(self, self.base_model_prefix) - - if main_layer is None: - raise NotImplementedError("The model does not implements the base_model_prefix attribute.") - - try: - main_layer.set_input_embeddings(value) - except AttributeError: - logger.info("Building the model") - self.build_in_name_scope() - main_layer.set_input_embeddings(value) - - def get_output_embeddings(self) -> None | keras.layers.Layer: - """ - Returns the model's output embeddings - - Returns: - `tf.Variable`: The new weights mapping vocabulary to hidden states. - """ - if self.get_lm_head() is not None: - lm_head = self.get_lm_head() - - try: - return lm_head.get_output_embeddings() - except AttributeError: - logger.info("Building the model") - self.build_in_name_scope() - - return lm_head().get_output_embeddings() - - return None # Overwrite for models with output embeddings - - def set_output_embeddings(self, value): - """ - Set model's output embeddings - - Args: - value (`tf.Variable`): - The new weights mapping hidden states to vocabulary. - """ - if self.get_lm_head() is not None: - lm_head = self.get_lm_head() - try: - lm_head.set_output_embeddings(value) - except AttributeError: - logger.info("Building the model") - self.build_in_name_scope() - lm_head.set_output_embeddings(value) - - def get_output_layer_with_bias(self) -> None | keras.layers.Layer: - """ - Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the - embeddings - - Return: - `keras.layers.Layer`: The layer that handles the bias, None if not an LM model. - """ - warnings.warn( - "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning - ) - return self.get_lm_head() - - def get_prefix_bias_name(self) -> None | str: - """ - Get the concatenated _prefix name of the bias from the model name to the parent layer - - Return: - `str`: The _prefix name of the bias. - """ - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return None - - def get_bias(self) -> None | dict[str, tf.Variable]: - """ - Dict of bias attached to an LM head. The key represents the name of the bias attribute. - - Return: - `tf.Variable`: The weights representing the bias, None if not an LM model. - """ - if self.get_lm_head() is not None: - lm_head = self.get_lm_head() - try: - return lm_head.get_bias() - except AttributeError: - self.build_in_name_scope() - - return lm_head.get_bias() - return None - - def set_bias(self, value): - """ - Set all the bias in the LM head. - - Args: - value (`dict[tf.Variable]`): - All the new bias attached to an LM head. - """ - if self.get_lm_head() is not None: - lm_head = self.get_lm_head() - try: - lm_head.set_bias(value) - except AttributeError: - self.build_in_name_scope() - lm_head.set_bias(value) - - def get_lm_head(self) -> keras.layers.Layer: - """ - The LM Head layer. This method must be overwritten by all the models that have a lm head. - - Return: - `keras.layers.Layer`: The LM head layer if the model has one, None if not. - """ - return None - - def resize_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding | tf.Variable: - """ - Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. - - Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. - - Arguments: - new_num_tokens (`int`, *optional*): - The number of new tokens in the embedding matrix. Increasing the size will add newly initialized - vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just - returns a pointer to the input tokens without doing anything. - - Return: - `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model. - """ - # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor - - # Run the new code path if the model has a keras embeddings layer - if isinstance(self.get_input_embeddings(), keras.layers.Embedding): - return self._v2_resized_token_embeddings(new_num_tokens) - - if new_num_tokens is None or new_num_tokens == self.config.vocab_size: - return self._get_word_embedding_weight(self.get_input_embeddings()) - - model_embeds = self._resize_token_embeddings(new_num_tokens) - - # Update base model and current model config - self.config.vocab_size = new_num_tokens - - return model_embeds - - def _v2_resized_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding: - """ - Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. - - Arguments: - new_num_tokens (`int`, *optional*): - The number of new tokens in the embedding matrix. Increasing the size will add newly initialized - vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just - returns a pointer to the input tokens without doing anything. - - Return: - `keras.layers.Embedding`: Pointer to the input tokens of the model. - """ - if new_num_tokens is None or new_num_tokens == self.config.vocab_size: - return self.get_input_embeddings() - - model_embeds = self._v2_resize_token_embeddings(new_num_tokens) - - # Update base model and current model config - self.config.vocab_size = new_num_tokens - - return model_embeds - - def _get_word_embedding_weight(model, embedding_layer): - # TODO (joao): flagged for detection due to embeddings refactor - - # If the variable holds the weights themselves, return them - if isinstance(embedding_layer, tf.Tensor): - return embedding_layer - # Otherwise, try to get them from the layer's attributes - - embeds = getattr(embedding_layer, "weight", None) - if embeds is not None: - return embeds - - embeds = getattr(embedding_layer, "decoder", None) - if embeds is not None: - return embeds - - # The reason why the attributes don't exist might be - # because the model is not built, so retry getting - # the argument after building the model - model.build_in_name_scope() - - embeds = getattr(embedding_layer, "weight", None) - if embeds is not None: - return embeds - - embeds = getattr(embedding_layer, "decoder", None) - if embeds is not None: - return embeds - - return None - - def _resize_token_embeddings(self, new_num_tokens): - # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor - old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings()) - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - - # if word embeddings are not tied, make sure that lm head bias is resized as well - if self.get_bias() is not None: - old_lm_head_bias = self.get_bias() - new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) - - self.set_bias(new_lm_head_bias) - - # if word embeddings are not tied, make sure that lm head decoder is resized as well - if self.get_output_embeddings() is not None: - old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) - new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) - - self.set_output_embeddings(new_lm_head_decoder) - - self.set_input_embeddings(new_embeddings) - - return self.get_input_embeddings() - - def _v2_resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.get_input_embeddings() - new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens) - self.set_input_embeddings(new_embeddings) - - # If word embeddings are not tied, make sure that lm head bias is resized as well - if self.get_bias() is not None: - old_lm_head_bias = self.get_bias() - new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) - self.set_bias(new_lm_head_bias) - - # If word embeddings are not tied, make sure that lm head decoder is resized as well. - tied_weights = self.get_input_embeddings() == self.get_output_embeddings() - if self.get_output_embeddings() is not None and not tied_weights: - old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) - # TODO (joao): this one probably needs a v2 version with other models - new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) - self.set_output_embeddings(new_lm_head_decoder) - - return self.get_input_embeddings() - - def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens): - """ - Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. - Reducing the size will remove vectors from the end - - Args: - old_lm_head_bias (`tf.Variable`): - Old lm head bias to be resized. - new_num_tokens (`int`, *optional*): - New number of tokens in the linear matrix. - - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove - vectors from the end. If not provided or `None`, just returns None - - Return: - `tf.Variable`: Pointer to the resized bias. - """ - # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor - new_lm_head_bias = {} - - for attr, weight in old_lm_head_bias.items(): - first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) - size_diff = new_num_tokens - old_num_tokens - final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens] - - # initialize new bias - if tf.math.greater(size_diff, 0): - padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] - current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1) - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy] - bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True) - bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False) - else: - slice_from = [0] if first_dim is None else [0, 0] - current_bias = tf.slice( - weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape) - ) - bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True) - - new_bias = self.add_weight( - shape=final_shape, - initializer="zeros", - trainable=True, - name=weight.name.split(":")[0], - ) - init_bias = tf.where(bias_mask, current_bias, new_bias.value()) - - new_bias.assign(init_bias) - new_lm_head_bias[attr] = new_bias - - return new_lm_head_bias - - def _v2_get_resized_lm_head_bias( - self, old_lm_head_bias: dict[str, tf.Variable], new_num_tokens: int - ) -> dict[str, tf.Tensor]: - """ - Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. - Reducing the size will remove vectors from the end - - Args: - old_lm_head_bias (`dict[str, tf.Variable]`): - Old lm head bias to be resized. - new_num_tokens (`int`): - New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at - the end. Reducing the size will remove vectors from the end. - - Return: - `tf.Tensor`: Values for the resized bias. - """ - new_lm_head_bias = {} - - for attr, weight in old_lm_head_bias.items(): - # Determine the size difference (depending on the shape) - first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) - size_diff = new_num_tokens - old_num_tokens - - # Copy the old bias values to the new bias - if old_num_tokens > new_num_tokens: - new_bias = weight.value()[..., :new_num_tokens] - else: - padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] - new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape)) - - new_lm_head_bias[attr] = new_bias - return new_lm_head_bias - - def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens): - """ - Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end. - Reducing the size will remove vectors from the end - - Args: - old_lm_head_decoder (`tf.Variable`): - Old lm head decoder to be resized. - new_num_tokens (`int`, *optional*): - New number of tokens in the linear matrix. - - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove - vectors from the end. If not provided or `None`, just returns None - - Return: - `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input - ones. - """ - new_lm_head_decoder = old_lm_head_decoder - is_input_output_equals = tf.reduce_any( - self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder - ) - - if old_lm_head_decoder is not None and not is_input_output_equals: - old_embedding_dim = shape_list(old_lm_head_decoder)[1] - decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens) - new_lm_head_decoder = self.add_weight( - shape=(new_num_tokens, old_embedding_dim), - initializer="zeros", - trainable=True, - name=old_lm_head_decoder.name.split(":")[0], - ) - init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value()) - - new_lm_head_decoder.assign(init_decoder) - - return new_lm_head_decoder - - def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable: - """ - Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly - initialized vectors at the end. Reducing the size will remove vectors from the end - - Args: - old_embeddings (`tf.Variable`): - Old embeddings to be resized. - new_num_tokens (`int`, *optional*): - New number of tokens in the embedding matrix. - - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove - vectors from the end. If not provided or `None`, just returns a pointer to the input tokens - `tf.Variable` module of the model without doing anything. - - Return: - `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is - `None` - """ - # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor - old_embedding_dim = shape_list(old_embeddings)[1] - init_range = getattr(self.config, "initializer_range", 0.02) - embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens) - new_embeddings = self.add_weight( - name=old_embeddings.name.split(":")[0], - shape=[new_num_tokens, old_embedding_dim], - initializer=get_initializer(init_range), - dtype=tf.float32, - ) - init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value()) - - new_embeddings.assign(init_embeddings) - - return new_embeddings - - def _v2_get_resized_embeddings( - self, old_embeddings: keras.layers.Embedding, new_num_tokens: int - ) -> keras.layers.Embedding: - """ - Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized - vectors at the end. Reducing the size will remove vectors from the end. - - Args: - old_embeddings (`keras.layers.Embedding`): - Old embeddings to be resized. - new_num_tokens (`int`, *optional*): - New number of tokens in the embedding matrix. - - Return: - `keras.layers.Embedding`: Resized Embedding layer. - """ - - # Get the initialization range for the embeddings - init_range = 0.02 # default value - potential_initialization_variable_names = [ - "initializer_range", # most common - "initializer_factor", # e.g. T5 - "init_std", # e.g BART - ] - for var_name in potential_initialization_variable_names: - if hasattr(self.config, var_name): - init_range = getattr(self.config, var_name) - - # Get a new (initialized) embeddings layer - new_embeddings = keras.layers.Embedding( - input_dim=new_num_tokens, - output_dim=old_embeddings.output_dim, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range), - name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0" - ) - new_embeddings(tf.constant([[0]])) - - # Copy the old embeddings to the new embeddings - if old_embeddings.input_dim >= new_num_tokens: - init_embeddings = old_embeddings.embeddings[:new_num_tokens] - else: - init_embeddings = tf.concat( - [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0 - ) - new_embeddings.embeddings.assign(init_embeddings) - return new_embeddings - - def prune_heads(self, heads_to_prune): - """ - Prunes heads of the base model. - - Arguments: - heads_to_prune (`dict[int, list[int]]`): - Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads - to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on - layer 1 and heads 2 and 3 on layer 2. - """ - raise NotImplementedError - - def save_pretrained( - self, - save_directory, - saved_model=False, - version=1, - push_to_hub=False, - signatures=None, - max_shard_size: int | str = "5GB", - create_pr: bool = False, - safe_serialization: bool = False, - token: str | bool | None = None, - **kwargs, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - [`~TFPreTrainedModel.from_pretrained`] class method. - - Arguments: - save_directory (`str`): - Directory to which to save. Will be created if it doesn't exist. - saved_model (`bool`, *optional*, defaults to `False`): - If the model has to be saved in saved model format as well or not. - version (`int`, *optional*, defaults to 1): - The version of the saved model. A saved model needs to be versioned in order to be properly loaded by - TensorFlow Serving as detailed in the official documentation - https://www.tensorflow.org/tfx/serving/serving_basic - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - signatures (`dict` or `tf.function`, *optional*): - Model's signature used for serving. This will be passed to the `signatures` argument of model.save(). - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): - The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size - lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - - - - If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard - which will be bigger than `max_shard_size`. - - - - create_pr (`bool`, *optional*, defaults to `False`): - Whether or not to create a PR with the uploaded files or directly commit. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`). - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - kwargs (`dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - """ - use_auth_token = kwargs.pop("use_auth_token", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if token is not None: - kwargs["token"] = token - - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = self._create_repo(repo_id, **kwargs) - files_timestamps = self._get_files_timestamps(save_directory) - - if saved_model: - # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string. - # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.) - if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): - self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] - if signatures is None: - serving_default = self.serving.get_concrete_function(self.input_signature) - if any(spec.dtype == tf.int32 for spec in self.input_signature.values()): - int64_spec = { - key: tf.TensorSpec( - shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name - ) - for key, spec in self.input_signature.items() - } - int64_serving = self.serving.get_concrete_function(int64_spec) - signatures = {"serving_default": serving_default, "int64_serving": int64_serving} - else: - signatures = serving_default - saved_model_dir = os.path.join(save_directory, "saved_model", str(version)) - self.save(saved_model_dir, include_optimizer=False, signatures=signatures) - logger.info(f"Saved model created in {saved_model_dir}") - - # Save configuration file - self.config.architectures = [self.__class__.__name__[2:]] - - # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be - # loaded from the Hub. - if self._auto_class is not None: - custom_object_save(self, save_directory, config=self.config) - - self.config.save_pretrained(save_directory) - if self.can_generate(): - self.generation_config.save_pretrained(save_directory) - - # If we save using the predefined names, we can load using `from_pretrained` - weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME - output_model_file = os.path.join(save_directory, weights_name) - - shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name) - - # Clean the folder from a previous save - for filename in os.listdir(save_directory): - full_filename = os.path.join(save_directory, filename) - # If we have a shard file that is not going to be replaced, we delete it, but only from the main process - # in distributed settings to avoid race conditions. - weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") - if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards: - os.remove(full_filename) - - if index is None: - if safe_serialization: - state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights} - safe_save_file(state_dict, output_model_file, metadata={"format": "tf"}) - else: - self.save_weights(output_model_file) - logger.info(f"Model weights saved in {output_model_file}") - else: - save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME - save_index_file = os.path.join(save_directory, save_index_file) - # Save the index as well - with open(save_index_file, "w", encoding="utf-8") as index_file: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - index_file.write(content) - logger.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - for shard_file, shard in shards.items(): - if safe_serialization: - shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard} - safe_save_file( - shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"} - ) - else: - with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: - layers = [] - for layer in sorted(shard, key=lambda x: x.name): - if "model." in layer.name or len(layer.name.split("/")) == 1: - layer_name = layer.name - else: - layer_name = "/".join(layer.name.split("/")[1:]) - param_dset = shard_file.create_dataset( - layer_name, layer.numpy().shape, dtype=layer.numpy().dtype - ) - param_dset[:] = layer.numpy() - layers.append(layer_name.encode("utf8")) - save_attributes_to_hdf5_group(shard_file, "layer_names", layers) - - if push_to_hub: - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=token, - ) - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str | os.PathLike | None, - *model_args, - config: PretrainedConfig | str | os.PathLike | None = None, - cache_dir: str | os.PathLike | None = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: str | bool | None = None, - revision: str = "main", - use_safetensors: bool | None = None, - **kwargs, - ): - r""" - Instantiate a pretrained TF 2.0 model from a pre-trained model configuration. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str`, *optional*): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this - case, `from_pt` should be set to `True` and a configuration object should be provided as `config` - argument. This loading path is slower than converting the PyTorch model in a TensorFlow model - using the provided conversion scripts and loading the TensorFlow model afterwards. - - `None` if you are both providing the configuration and state dictionary (resp. with keyword - arguments `config` and `state_dict`). - model_args (sequence of positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - config (`Union[PretrainedConfig, str]`, *optional*): - Can be either: - - - an instance of a class derived from [`PretrainedConfig`], - - a string valid as input to [`~PretrainedConfig.from_pretrained`]. - - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch state_dict save file (see docstring of - `pretrained_model_name_or_path` argument). - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). - cache_dir (`str`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies: - (`dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., - `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): Whether ot not to also return a - dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can - specify the folder name here. - tf_to_pt_weight_rename (`Callable`, *optional*): - A function that is called to transform the names of weights during the PyTorch to TensorFlow - crossloading process. This is not necessary for most models, but is useful to allow composite models to - be crossloaded correctly. - use_safetensors (`bool`, *optional*, defaults to `None`): - Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` - is not installed, it will be set to `False`. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import BertConfig, TFBertModel - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased") - >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). - >>> model = TFBertModel.from_pretrained("./test/saved_model/") - >>> # Update configuration during loading. - >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) - >>> assert model.config.output_attentions == True - >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable). - >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json") - >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config) - ```""" - from_pt = kwargs.pop("from_pt", False) - resume_download = kwargs.pop("resume_download", None) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - use_auth_token = kwargs.pop("use_auth_token", None) - trust_remote_code = kwargs.pop("trust_remote_code", None) - _ = kwargs.pop("mirror", None) - load_weight_prefix = kwargs.pop("load_weight_prefix", None) - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - subfolder = kwargs.pop("subfolder", "") - commit_hash = kwargs.pop("_commit_hash", None) - tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) - - # Not relevant for TF models - _ = kwargs.pop("adapter_kwargs", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if trust_remote_code is True: - logger.warning( - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" - " ignored." - ) - - user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class} - if from_pipeline is not None: - user_agent["using_pipeline"] = from_pipeline - - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - if use_safetensors is None and not is_safetensors_available(): - use_safetensors = False - - # Load config if we don't provide a configuration - if not isinstance(config, PretrainedConfig): - config_path = config if config is not None else pretrained_model_name_or_path - config, model_kwargs = cls.config_class.from_pretrained( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - _from_auto=from_auto_class, - _from_pipeline=from_pipeline, - _commit_hash=commit_hash, - **kwargs, - ) - else: - model_kwargs = kwargs - - if commit_hash is None: - commit_hash = getattr(config, "_commit_hash", None) - - # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the - # index of the files. - is_sharded = False - # Load model - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): - # Load from a PyTorch checkpoint in priority if from_pt - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) - is_sharded = True - elif use_safetensors is not False and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) - elif use_safetensors is not False and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - is_sharded = True - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): - # Load from a TF 2.0 checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)): - # Load from a sharded TF 2.0 checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) - is_sharded = True - - # At this stage we don't have a weight file so we will raise an error. - elif use_safetensors: - raise OSError( - f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. " - f"Please make sure that the model has been saved with `safe_serialization=True` or do not " - f"set `use_safetensors=True`." - ) - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( - os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) - ): - raise OSError( - f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " - "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " - "weights." - ) - else: - raise OSError( - f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " - f"{pretrained_model_name_or_path}." - ) - elif os.path.isfile(pretrained_model_name_or_path): - archive_file = pretrained_model_name_or_path - is_local = True - elif os.path.isfile(pretrained_model_name_or_path + ".index"): - archive_file = pretrained_model_name_or_path + ".index" - is_local = True - elif is_remote_url(pretrained_model_name_or_path): - filename = pretrained_model_name_or_path - resolved_archive_file = download_url(pretrained_model_name_or_path) - else: - # set correct filename - if from_pt: - filename = WEIGHTS_NAME - elif use_safetensors is not False: - filename = SAFE_WEIGHTS_NAME - else: - filename = TF2_WEIGHTS_NAME - - try: - # Load from URL or cache if already cached - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "token": token, - "user_agent": user_agent, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: - # Did not find the safetensors file, let's fallback to TF. - # No support for sharded safetensors yet, so we'll raise an error if that's all we find. - filename = TF2_WEIGHTS_NAME - resolved_archive_file = cached_file( - pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs - ) - if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs - ) - if resolved_archive_file is not None: - is_sharded = True - if resolved_archive_file is None and filename == WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs - ) - if resolved_archive_file is not None: - is_sharded = True - if resolved_archive_file is None: - # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error - # message. - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): - is_sharded = True - elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" - " load this model from those weights." - ) - else: - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," - f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" - ) - - except OSError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted - # to the original exception. - raise - except Exception: - # For any other exception, we throw a generic error. - - raise OSError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - " from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" - ) - if is_local: - logger.info(f"loading weights file {archive_file}") - resolved_archive_file = archive_file - filename = resolved_archive_file.split(os.path.sep)[-1] - else: - logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: - resolved_archive_file = None - - # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - _commit_hash=commit_hash, - ) - - safetensors_from_pt = False - if filename == SAFE_WEIGHTS_NAME: - with safe_open(resolved_archive_file, framework="tf") as f: - safetensors_metadata = f.metadata() - if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: - raise OSError( - f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." - " Make sure you save your model with the `save_pretrained` method." - ) - safetensors_from_pt = safetensors_metadata.get("format") == "pt" - elif filename == SAFE_WEIGHTS_INDEX_NAME: - with safe_open(resolved_archive_file[0], framework="tf") as f: - safetensors_metadata = f.metadata() - if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: - raise OSError( - f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." - " Make sure you save your model with the `save_pretrained` method." - ) - safetensors_from_pt = safetensors_metadata.get("format") == "pt" - - config.name_or_path = pretrained_model_name_or_path - - # composed models, *e.g.* TFRag, require special treatment when it comes to loading - # pre-trained weights. - if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None: - model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name") - - # Instantiate model. - model = cls(config, *model_args, **model_kwargs) - - if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"): - # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method - # to be defined for each class that requires a rename. We can probably just have a class-level - # dict and a single top-level method or something and cut down a lot of boilerplate code - tf_to_pt_weight_rename = model.tf_to_pt_weight_rename - - if from_pt: - from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model - - # Load from a PyTorch checkpoint - return load_pytorch_checkpoint_in_tf2_model( - model, - resolved_archive_file, - allow_missing_keys=True, - output_loading_info=output_loading_info, - _prefix=load_weight_prefix, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ) - - # we might need to extend the variable scope for composite models - if load_weight_prefix is not None: - with tf.compat.v1.variable_scope(load_weight_prefix): - model.build_in_name_scope() # build the network with dummy inputs - else: - model.build_in_name_scope() # build the network with dummy inputs - - if safetensors_from_pt and not is_sharded: - from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model - - with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: - # Load from a PyTorch safetensors checkpoint - # We load in TF format here because PT weights often need to be transposed, and this is much - # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times. - return load_pytorch_state_dict_in_tf2_model( - model, - safetensors_archive, - tf_inputs=False, # No need to build the model again - allow_missing_keys=True, - output_loading_info=output_loading_info, - _prefix=load_weight_prefix, - ignore_mismatched_sizes=ignore_mismatched_sizes, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ) - elif safetensors_from_pt: - from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model - - return load_sharded_pytorch_safetensors_in_tf2_model( - model, - resolved_archive_file, - tf_inputs=False, - allow_missing_keys=True, - output_loading_info=output_loading_info, - _prefix=load_weight_prefix, - ignore_mismatched_sizes=ignore_mismatched_sizes, - tf_to_pt_weight_rename=tf_to_pt_weight_rename, - ) - - # 'by_name' allow us to do transfer learning by skipping/adding layers - # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 - try: - if is_sharded: - for file in resolved_archive_file: - os.path.isfile(file), f"Error retrieving files {file}" - if filename == SAFE_WEIGHTS_INDEX_NAME: - missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors( - model, - resolved_archive_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=load_weight_prefix, - ) - else: - missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( - model, - resolved_archive_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=load_weight_prefix, - ) - else: - # Handles both H5 and safetensors - missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( - model, - resolved_archive_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=load_weight_prefix, - ) - except OSError as e: - try: - with open(resolved_archive_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise OSError( - "Unable to load weights from h5 file. " - "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. " - ) - - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" - " with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" - " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - else: - logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n") - - if len(missing_keys) > 0: - logger.warning( - f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.warning( - f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" - f" was trained on, you can already use {model.__class__.__name__} for predictions without further" - " training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." - ) - - # If it is a model with generation capabilities, attempt to load the generation config - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - _from_auto=from_auto_class, - _from_pipeline=from_pipeline, - **kwargs, - ) - except OSError: - logger.info( - "Generation config file not found, using a generation config created from the model config." - ) - pass - - if output_loading_info: - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - } - - return model, loading_info - - return model - - def push_to_hub( - self, - repo_id: str, - use_temp_dir: bool | None = None, - commit_message: str | None = None, - private: bool | None = None, - max_shard_size: int | str | None = "10GB", - token: bool | str | None = None, - # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs) - use_auth_token: bool | str | None = None, - create_pr: bool = False, - **base_model_card_args, - ) -> str: - """ - Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`. - - Parameters: - repo_id (`str`): - The name of the repository you want to push your model to. It should contain your organization name - when pushing to a given organization. - use_temp_dir (`bool`, *optional*): - Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub. - Will default to `True` if there is no directory named like `repo_id`, `False` otherwise. - commit_message (`str`, *optional*): - Message to commit while pushing. Will default to `"Upload model"`. - private (`bool`, *optional*): - Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. - token (`bool` or `str`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `hf auth login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` - is not specified. - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): - Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard - will then be each of size lower than this size. If expressed as a string, needs to be digits followed - by a unit (like `"5MB"`). - create_pr (`bool`, *optional*, defaults to `False`): - Whether or not to create a PR with the uploaded files or directly commit. - - Examples: - - ```python - from transformers import TFAutoModel - - model = TFAutoModel.from_pretrained("google-bert/bert-base-cased") - - # Push the model to your namespace with the name "my-finetuned-bert". - model.push_to_hub("my-finetuned-bert") - - # Push the model to an organization with the name "my-finetuned-bert". - model.push_to_hub("huggingface/my-finetuned-bert") - ``` - """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if "repo_path_or_name" in base_model_card_args: - warnings.warn( - "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use " - "`repo_id` instead." - ) - repo_id = base_model_card_args.pop("repo_path_or_name") - # Deprecation warning will be sent after for repo_url and organization - repo_url = base_model_card_args.pop("repo_url", None) - organization = base_model_card_args.pop("organization", None) - - if os.path.isdir(repo_id): - working_dir = repo_id - repo_id = repo_id.split(os.path.sep)[-1] - else: - working_dir = repo_id.split("/")[-1] - - repo_id = self._create_repo( - repo_id, private=private, token=token, repo_url=repo_url, organization=organization - ) - - if use_temp_dir is None: - use_temp_dir = not os.path.isdir(working_dir) - - with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir: - files_timestamps = self._get_files_timestamps(work_dir) - - # Save all files. - self.save_pretrained(work_dir, max_shard_size=max_shard_size) - if hasattr(self, "history") and hasattr(self, "create_model_card"): - # This is a Keras model and we might be able to fish out its History and make a model card out of it - base_model_card_args = { - "output_dir": work_dir, - "model_name": Path(repo_id).name, - } - base_model_card_args.update(base_model_card_args) - self.create_model_card(**base_model_card_args) - - self._upload_modified_files( - work_dir, - repo_id, - files_timestamps, - commit_message=commit_message, - token=token, - create_pr=create_pr, - ) - - @classmethod - def register_for_auto_class(cls, auto_class="TFAutoModel"): - """ - Register this class with a given auto class. This should only be used for custom models as the ones in the - library are already mapped with an auto class. - - - - Args: - auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`): - The auto class to register this new model with. - """ - if not isinstance(auto_class, str): - auto_class = auto_class.__name__ - - import transformers.models.auto as auto_module - - if not hasattr(auto_module, auto_class): - raise ValueError(f"{auto_class} is not a valid auto class.") - - cls._auto_class = auto_class - - -class TFConv1D(keras.layers.Layer): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - - Basically works like a linear layer but the weights are transposed. - - Args: - nf (`int`): - The number of output features. - nx (`int`): - The number of input features. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation to use to initialize the weights. - kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. - """ - - def __init__(self, nf, nx, initializer_range=0.02, **kwargs): - super().__init__(**kwargs) - self.nf = nf - self.nx = nx - self.initializer_range = initializer_range - - def build(self, input_shape): - if self.built: - return - self.built = True - self.weight = self.add_weight( - "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range) - ) - self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer()) - - def call(self, x): - bz, sl = shape_list(x)[:2] - - x = tf.reshape(x, [-1, self.nx]) - x = tf.matmul(x, self.weight) + self.bias - - x = tf.reshape(x, [bz, sl, self.nf]) - - return x - - -class TFSharedEmbeddings(keras.layers.Layer): - r""" - Construct shared token embeddings. - - The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language - modeling. - - Args: - vocab_size (`int`): - The size of the vocabulary, e.g., the number of unique tokens. - hidden_size (`int`): - The size of the embedding vectors. - initializer_range (`float`, *optional*): - The standard deviation to use when initializing the weights. If no value is provided, it will default to - \\(1/\sqrt{hidden\_size}\\). - kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. - """ - - # TODO (joao): flagged for detection due to embeddings refactor - - def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float | None = None, **kwargs): - super().__init__(**kwargs) - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range - warnings.warn( - "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.", - DeprecationWarning, - ) - - def build(self, input_shape): - """ - Build shared token embedding layer Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - self.weight = self.add_weight( - "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range) - ) - super().build(input_shape) - - def get_config(self): - config = { - "vocab_size": self.vocab_size, - "hidden_size": self.hidden_size, - "initializer_range": self.initializer_range, - } - base_config = super().get_config() - - return dict(list(base_config.items()) + list(config.items())) - - def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor: - """ - Get token embeddings of inputs or decode final hidden state. - - Args: - inputs (`tf.Tensor`): - In embedding mode, should be an int64 tensor with shape `[batch_size, length]`. - - In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`. - mode (`str`, defaults to `"embedding"`): - A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be - used as an embedding layer, the second one that the layer should be used as a linear decoder. - - Returns: - `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length, - embedding_size]`. - - In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`. - - Raises: - ValueError: if `mode` is not valid. - - Shared weights logic is adapted from - [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24). - """ - if mode == "embedding": - return self._embedding(inputs) - elif mode == "linear": - return self._linear(inputs) - else: - raise ValueError(f"mode {mode} is not valid.") - - def _embedding(self, input_ids): - """Applies embedding based on inputs tensor.""" - return tf.gather(self.weight, input_ids) - - def _linear(self, inputs): - """ - Computes logits by running inputs through a linear layer. - - Args: - inputs: A float32 tensor with shape [..., hidden_size] - - Returns: - float32 tensor with shape [..., vocab_size]. - """ - first_dims = shape_list(inputs)[:-1] - x = tf.reshape(inputs, [-1, self.hidden_size]) - logits = tf.matmul(x, self.weight, transpose_b=True) - - return tf.reshape(logits, first_dims + [self.vocab_size]) - - -class TFSequenceSummary(keras.layers.Layer): - """ - Compute a single vector summary of a sequence hidden states. - - Args: - config ([`PretrainedConfig`]): - The config used by the model. Relevant arguments in the config class of the model are (refer to the actual - config class of your model for the default values it uses): - - - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: - - - `"last"` -- Take the last token hidden state (like XLNet) - - `"first"` -- Take the first token hidden state (like Bert) - - `"mean"` -- Take the mean of all tokens hidden states - - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) - - `"attn"` -- Not implemented now, use multi-head attention - - - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. - - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes - (otherwise to `config.hidden_size`). - - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, - another string or `None` will add no activation. - - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. - - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. - - initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights. - kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. - """ - - def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs): - super().__init__(**kwargs) - - self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last" - if self.summary_type == "attn": - # We should use a standard multi-head attention module with absolute positional embedding for that. - # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 - # We can probably just use the multi-head attention module of PyTorch >=1.1.0 - raise NotImplementedError - - self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj - if self.has_summary: - if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: - num_classes = config.num_labels - else: - num_classes = config.hidden_size - self.summary = keras.layers.Dense( - num_classes, kernel_initializer=get_initializer(initializer_range), name="summary" - ) - - self.has_activation = False - activation_string = getattr(config, "summary_activation", None) - if activation_string is not None: - self.has_activation = True - self.activation = get_tf_activation(activation_string) - - self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0 - if self.has_first_dropout: - self.first_dropout = keras.layers.Dropout(config.summary_first_dropout) - - self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0 - if self.has_last_dropout: - self.last_dropout = keras.layers.Dropout(config.summary_last_dropout) - self.hidden_size = config.hidden_size - - def call(self, inputs, cls_index=None, training=False): - if not isinstance(inputs, (dict, tuple, list)): - hidden_states = inputs - elif isinstance(inputs, (tuple, list)): - hidden_states = inputs[0] - cls_index = inputs[1] if len(inputs) > 1 else None - assert len(inputs) <= 2, "Too many inputs." - else: - hidden_states = inputs.get("hidden_states") - cls_index = inputs.get("cls_index", None) - - if self.summary_type == "last": - output = hidden_states[:, -1] - elif self.summary_type == "first": - output = hidden_states[:, 0] - elif self.summary_type == "mean": - output = tf.reduce_mean(hidden_states, axis=1) - elif self.summary_type == "cls_index": - hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] - if cls_index is None: - cls_index = tf.fill( - hidden_shape[:-2], hidden_shape[-2] - 1 - ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length - cls_shape = shape_list(cls_index) - if len(cls_shape) <= len(hidden_shape) - 2: - cls_index = tf.expand_dims(cls_index, axis=-1) - # else: - # cls_index = cls_index[..., tf.newaxis] - # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) - # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states - output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2) - output = tf.squeeze( - output, axis=len(hidden_shape) - 2 - ) # shape of output: (batch, num choices, hidden_size) - elif self.summary_type == "attn": - raise NotImplementedError - - if self.has_first_dropout: - output = self.first_dropout(output, training=training) - - if self.has_summary: - output = self.summary(output) - - if self.has_activation: - output = self.activation(output) - - if self.has_last_dropout: - output = self.last_dropout(output, training=training) - - return output - - def build(self, input_shape): - if self.built: - return - self.built = True - if getattr(self, "summary", None) is not None: - with tf.name_scope("summary"): - self.summary.build(self.hidden_size) - - -def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal: - """ - Creates a `keras.initializers.TruncatedNormal` with the given range. - - Args: - initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range. - - Returns: - `keras.initializers.TruncatedNormal`: The truncated normal initializer. - """ - return keras.initializers.TruncatedNormal(stddev=initializer_range) diff --git a/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index df2a22610187..000000000000 --- a/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,62 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert ALBERT checkpoint.""" - -import argparse - -import torch - -from ...utils import logging -from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): - # Initialise PyTorch model - config = AlbertConfig.from_json_file(albert_config_file) - print(f"Building PyTorch model from configuration: {config}") - model = AlbertForPreTraining(config) - - # Load weights from tf checkpoint - load_tf_weights_in_albert(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--albert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained ALBERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py deleted file mode 100644 index f2f19cb27716..000000000000 --- a/src/transformers/models/albert/modeling_flax_albert.py +++ /dev/null @@ -1,1132 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPooling, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_albert import AlbertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "albert/albert-base-v2" -_CONFIG_FOR_DOC = "AlbertConfig" - - -@flax.struct.dataclass -class FlaxAlbertForPreTrainingOutput(ModelOutput): - """ - Output type of [`FlaxAlbertForPreTraining`]. - - Args: - prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - prediction_logits: jnp.ndarray = None - sop_logits: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -ALBERT_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`AlbertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -ALBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - -""" - - -class FlaxAlbertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxAlbertSelfAttention(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): - head_dim = self.config.hidden_size // self.config.num_attention_heads - - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - projected_attn_output = self.dense(attn_output) - projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic) - layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states) - outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,) - return outputs - - -class FlaxAlbertLayer(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype) - self.ffn = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - self.ffn_output = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - ): - attention_outputs = self.attention( - hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions - ) - attention_output = attention_outputs[0] - ffn_output = self.ffn(attention_output) - ffn_output = self.activation(ffn_output) - ffn_output = self.ffn_output(ffn_output) - ffn_output = self.dropout(ffn_output, deterministic=deterministic) - hidden_states = self.full_layer_layer_norm(ffn_output + attention_output) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - return outputs - - -class FlaxAlbertLayerCollection(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num) - ] - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - layer_hidden_states = () - layer_attentions = () - - for layer_index, albert_layer in enumerate(self.layers): - layer_output = albert_layer( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - hidden_states = layer_output[0] - - if output_attentions: - layer_attentions = layer_attentions + (layer_output[1],) - - if output_hidden_states: - layer_hidden_states = layer_hidden_states + (hidden_states,) - - outputs = (hidden_states,) - if output_hidden_states: - outputs = outputs + (layer_hidden_states,) - if output_attentions: - outputs = outputs + (layer_attentions,) - return outputs # last-layer hidden state, (layer hidden states), (layer attentions) - - -class FlaxAlbertLayerCollections(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - layer_index: Optional[str] = None - - def setup(self): - self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - outputs = self.albert_layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - return outputs - - -class FlaxAlbertLayerGroups(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_groups) - ] - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = (hidden_states,) if output_hidden_states else None - - for i in range(self.config.num_hidden_layers): - # Index of the hidden group - group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) - layer_group_output = self.layers[group_idx]( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - hidden_states = layer_group_output[0] - - if output_attentions: - all_attentions = all_attentions + layer_group_output[-1] - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxAlbertEncoder(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embedding_hidden_mapping_in = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embedding_hidden_mapping_in(hidden_states) - return self.albert_layer_groups( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - -class FlaxAlbertOnlyMLMHead(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) - self.activation = ACT2FN[self.config.hidden_act] - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - hidden_states += self.bias - return hidden_states - - -class FlaxAlbertSOPHead(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dropout = nn.Dropout(self.config.classifier_dropout_prob) - self.classifier = nn.Dense(2, dtype=self.dtype) - - def __call__(self, pooled_output, deterministic=True): - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - return logits - - -class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = AlbertConfig - base_model_prefix = "albert" - module_class: nn.Module = None - - def __init__( - self, - config: AlbertConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.zeros_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - attention_mask = jnp.ones_like(input_ids) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxAlbertModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - - def setup(self): - self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype) - if self.add_pooling_layer: - self.pooler = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - name="pooler", - ) - self.pooler_activation = nn.tanh - else: - self.pooler = None - self.pooler_activation = None - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[np.ndarray] = None, - position_ids: Optional[np.ndarray] = None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic) - - outputs = self.encoder( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - if self.add_pooling_layer: - pooled = self.pooler(hidden_states[:, 0]) - pooled = self.pooler_activation(pooled) - else: - pooled = None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPooling( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.", - ALBERT_START_DOCSTRING, -) -class FlaxAlbertModel(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertModule - - -append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) - - -class FlaxAlbertForPreTrainingModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) - self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) - self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.config.tie_word_embeddings: - shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - hidden_states = outputs[0] - pooled_output = outputs[1] - - prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) - sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic) - - if not return_dict: - return (prediction_scores, sop_scores) + outputs[2:] - - return FlaxAlbertForPreTrainingOutput( - prediction_logits=prediction_scores, - sop_logits=sop_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a - `sentence order prediction (classification)` head. - """, - ALBERT_START_DOCSTRING, -) -class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForPreTrainingModule - - -FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") - >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.prediction_logits - >>> seq_relationship_logits = outputs.sop_logits - ``` -""" - -overwrite_call_docstring( - FlaxAlbertForPreTraining, - ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING, -) -append_replace_return_docstrings( - FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxAlbertForMaskedLMModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) - self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.predictions(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING) -class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForMaskedLMModule - - -append_call_sample_docstring( - FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11" -) - - -class FlaxAlbertForSequenceClassificationModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) - classifier_dropout = ( - self.config.classifier_dropout_prob - if self.config.classifier_dropout_prob is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - if not return_dict: - return (logits,) + outputs[2:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - ALBERT_START_DOCSTRING, -) -class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxAlbertForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxAlbertForMultipleChoiceModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ALBERT_START_DOCSTRING, -) -class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxAlbertForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxAlbertForTokenClassificationModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - classifier_dropout = ( - self.config.classifier_dropout_prob - if self.config.classifier_dropout_prob is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ALBERT_START_DOCSTRING, -) -class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForTokenClassificationModule - - -append_call_sample_docstring( - FlaxAlbertForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxAlbertForQuestionAnsweringModule(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.albert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ALBERT_START_DOCSTRING, -) -class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel): - module_class = FlaxAlbertForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxAlbertForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - -__all__ = [ - "FlaxAlbertPreTrainedModel", - "FlaxAlbertModel", - "FlaxAlbertForPreTraining", - "FlaxAlbertForMaskedLM", - "FlaxAlbertForSequenceClassification", - "FlaxAlbertForMultipleChoice", - "FlaxAlbertForTokenClassification", - "FlaxAlbertForQuestionAnswering", -] diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py deleted file mode 100644 index 101ab63dc054..000000000000 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ /dev/null @@ -1,1572 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ALBERT model.""" - -from __future__ import annotations - -import math -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_albert import AlbertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "albert/albert-base-v2" -_CONFIG_FOR_DOC = "AlbertConfig" - - -class TFAlbertPreTrainingLoss: - """ - Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP + - MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. - """ - - def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - if self.config.tf_legacy_loss: - # make sure only labels that are not equal to -100 - # are taken into account as loss - masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100) - masked_lm_reduced_logits = tf.boolean_mask( - tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])), - mask=masked_lm_active_loss, - ) - masked_lm_labels = tf.boolean_mask( - tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss - ) - sentence_order_active_loss = tf.not_equal( - tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100 - ) - sentence_order_reduced_logits = tf.boolean_mask( - tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss - ) - sentence_order_label = tf.boolean_mask( - tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss - ) - masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits) - sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits) - masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0])) - masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0) - - return masked_lm_loss + sentence_order_loss - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0]) - # make sure only labels that are not equal to -100 - # are taken into account for the loss computation - lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype) - masked_lm_losses = unmasked_lm_losses * lm_loss_mask - reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask) - - sop_logits = tf.reshape(logits[1], (-1, 2)) - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits) - sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype) - - masked_sop_loss = unmasked_sop_loss * sop_loss_mask - reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask) - - return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,)) - - -class TFAlbertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - past_key_values_length=0, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFAlbertAttention(keras.layers.Layer): - """Contains the complete attention sublayer, including both dropouts and layer norm.""" - - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - self.output_attentions = config.output_attentions - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993 - self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(input_tensor)[0] - mixed_query_layer = self.query(inputs=input_tensor) - mixed_key_layer = self.key(inputs=input_tensor) - mixed_value_layer = self.value(inputs=input_tensor) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size)) - self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - hidden_states = self_outputs[0] - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.output_dropout(inputs=hidden_states, training=training) - attention_output = self.LayerNorm(inputs=hidden_states + input_tensor) - - # add attentions if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFAlbertLayer(keras.layers.Layer): - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFAlbertAttention(config, name="attention") - self.ffn = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn" - ) - - if isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - - self.ffn_output = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output" - ) - self.full_layer_layer_norm = keras.layers.LayerNormalization( - epsilon=config.layer_norm_eps, name="full_layer_layer_norm" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - ffn_output = self.ffn(inputs=attention_outputs[0]) - ffn_output = self.activation(ffn_output) - ffn_output = self.ffn_output(inputs=ffn_output) - ffn_output = self.dropout(inputs=ffn_output, training=training) - hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0]) - - # add attentions if we output them - outputs = (hidden_states,) + attention_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "ffn", None) is not None: - with tf.name_scope(self.ffn.name): - self.ffn.build([None, None, self.config.hidden_size]) - if getattr(self, "ffn_output", None) is not None: - with tf.name_scope(self.ffn_output.name): - self.ffn_output.build([None, None, self.config.intermediate_size]) - if getattr(self, "full_layer_layer_norm", None) is not None: - with tf.name_scope(self.full_layer_layer_norm.name): - self.full_layer_layer_norm.build([None, None, self.config.hidden_size]) - - -class TFAlbertLayerGroup(keras.layers.Layer): - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - self.albert_layers = [ - TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num) - ] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - layer_hidden_states = () if output_hidden_states else None - layer_attentions = () if output_attentions else None - - for layer_index, albert_layer in enumerate(self.albert_layers): - if output_hidden_states: - layer_hidden_states = layer_hidden_states + (hidden_states,) - - layer_output = albert_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[layer_index], - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_output[0] - - if output_attentions: - layer_attentions = layer_attentions + (layer_output[1],) - - # Add last layer - if output_hidden_states: - layer_hidden_states = layer_hidden_states + (hidden_states,) - - return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert_layers", None) is not None: - for layer in self.albert_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFAlbertTransformer(keras.layers.Layer): - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - self.num_hidden_layers = config.num_hidden_layers - self.num_hidden_groups = config.num_hidden_groups - # Number of layers in a hidden group - self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups) - self.embedding_hidden_mapping_in = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="embedding_hidden_mapping_in", - ) - self.albert_layer_groups = [ - TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups) - ] - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states) - all_attentions = () if output_attentions else None - all_hidden_states = (hidden_states,) if output_hidden_states else None - - for i in range(self.num_hidden_layers): - # Index of the hidden group - group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups)) - layer_group_output = self.albert_layer_groups[group_idx]( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - training=training, - ) - hidden_states = layer_group_output[0] - - if output_attentions: - all_attentions = all_attentions + layer_group_output[-1] - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedding_hidden_mapping_in", None) is not None: - with tf.name_scope(self.embedding_hidden_mapping_in.name): - self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size]) - if getattr(self, "albert_layer_groups", None) is not None: - for layer in self.albert_layer_groups: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFAlbertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = AlbertConfig - base_model_prefix = "albert" - - -class TFAlbertMLMHead(keras.layers.Layer): - def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.dense = keras.layers.Dense( - config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - if isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - self.decoder_bias = self.add_weight( - shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias" - ) - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.decoder - - def set_output_embeddings(self, value: tf.Variable): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias, "decoder_bias": self.decoder_bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.decoder_bias = value["decoder_bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias) - - return hidden_states - - -@keras_serializable -class TFAlbertMainLayer(keras.layers.Layer): - config_class = AlbertConfig - - def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFAlbertEmbeddings(config, name="embeddings") - self.encoder = TFAlbertTransformer(config, name="encoder") - self.pooler = ( - keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="pooler", - ) - if add_pooling_layer - else None - ) - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build([None, None, self.config.hidden_size]) - - -@dataclass -class TFAlbertForPreTrainingOutput(ModelOutput): - """ - Output type of [`TFAlbertForPreTraining`]. - - Args: - prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - sop_logits (`tf.Tensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - prediction_logits: tf.Tensor | None = None - sop_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -ALBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`AlbertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ALBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.", - ALBERT_START_DOCSTRING, -) -class TFAlbertModel(TFAlbertPreTrainedModel): - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.albert = TFAlbertMainLayer(config, name="albert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - - -@add_start_docstrings( - """ - Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order - prediction` (classification) head. - """, - ALBERT_START_DOCSTRING, -) -class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.albert = TFAlbertMainLayer(config, name="albert") - self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions") - self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier") - - def get_lm_head(self) -> keras.layers.Layer: - return self.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - sentence_order_label: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]: - r""" - Return: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFAlbertForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") - >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2") - - >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] - >>> # Batch size 1 - >>> outputs = model(input_ids) - - >>> prediction_logits = outputs.prediction_logits - >>> sop_logits = outputs.sop_logits - ```""" - - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output, pooled_output = outputs[:2] - prediction_scores = self.predictions(hidden_states=sequence_output) - sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training) - total_loss = None - - if labels is not None and sentence_order_label is not None: - d_labels = {"labels": labels} - d_labels["sentence_order_label"] = sentence_order_label - total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores)) - - if not return_dict: - output = (prediction_scores, sop_scores) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return TFAlbertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - sop_logits=sop_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - if getattr(self, "sop_classifier", None) is not None: - with tf.name_scope(self.sop_classifier.name): - self.sop_classifier.build(None) - - -class TFAlbertSOPHead(keras.layers.Layer): - def __init__(self, config: AlbertConfig, **kwargs): - super().__init__(**kwargs) - - self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor: - dropout_pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=dropout_pooled_output) - - return logits - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING) -class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") - self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions") - - def get_lm_head(self) -> keras.layers.Layer: - return self.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM - - >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") - >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2") - - >>> # add mask_token - >>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf") - >>> logits = model(**inputs).logits - - >>> # retrieve index of [MASK] - >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] - >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) - >>> tokenizer.decode(predicted_token_id) - 'france' - ``` - - ```python - >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] - >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) - >>> outputs = model(**inputs, labels=labels) - >>> round(float(outputs.loss), 2) - 0.81 - ``` - """ - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.predictions(hidden_states=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@add_start_docstrings( - """ - Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - ALBERT_START_DOCSTRING, -) -class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"predictions"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.albert = TFAlbertMainLayer(config, name="albert") - self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="vumichien/albert-base-v2-imdb", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'LABEL_1'", - expected_loss=0.12, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ALBERT_START_DOCSTRING, -) -class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") - classifier_dropout_prob = ( - config.classifier_dropout_prob - if config.classifier_dropout_prob is not None - else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ALBERT_START_DOCSTRING, -) -class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="vumichien/albert-base-v2-squad2", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - qa_target_start_index=12, - qa_target_end_index=13, - expected_output="'a nice puppet'", - expected_loss=7.36, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.albert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ALBERT_START_DOCSTRING, -) -class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: AlbertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.albert = TFAlbertMainLayer(config, name="albert") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = ( - tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None - ) - flat_token_type_ids = ( - tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None - ) - flat_position_ids = ( - tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None - ) - flat_inputs_embeds = ( - tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.albert( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "albert", None) is not None: - with tf.name_scope(self.albert.name): - self.albert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFAlbertPreTrainedModel", - "TFAlbertModel", - "TFAlbertForPreTraining", - "TFAlbertForMaskedLM", - "TFAlbertForSequenceClassification", - "TFAlbertForTokenClassification", - "TFAlbertForQuestionAnswering", - "TFAlbertForMultipleChoice", - "TFAlbertMainLayer", -] diff --git a/src/transformers/models/align/convert_align_tf_to_hf.py b/src/transformers/models/align/convert_align_tf_to_hf.py deleted file mode 100644 index 74309a0d7076..000000000000 --- a/src/transformers/models/align/convert_align_tf_to_hf.py +++ /dev/null @@ -1,389 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert ALIGN checkpoints from the original repository.""" - -import argparse -import os - -import align -import numpy as np -import requests -import tensorflow as tf -import torch -from PIL import Image -from tokenizer import Tokenizer - -from transformers import ( - AlignConfig, - AlignModel, - AlignProcessor, - BertConfig, - BertTokenizer, - EfficientNetConfig, - EfficientNetImageProcessor, -) -from transformers.utils import logging - - -logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -def preprocess(image): - image = tf.image.resize(image, (346, 346)) - image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289) - return image - - -def get_align_config(): - vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7") - vision_config.image_size = 289 - vision_config.hidden_dim = 640 - vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"} - vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1} - vision_config.depthwise_padding = [] - - text_config = BertConfig() - config = AlignConfig.from_text_vision_configs( - text_config=text_config, vision_config=vision_config, projection_dim=640 - ) - return config - - -# We will verify our results on an image of cute cats -def prepare_img(): - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - im = Image.open(requests.get(url, stream=True).raw) - return im - - -def get_processor(): - image_processor = EfficientNetImageProcessor( - do_center_crop=True, - rescale_factor=1 / 127.5, - rescale_offset=True, - do_normalize=False, - include_top=False, - resample=Image.BILINEAR, - ) - tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") - tokenizer.model_max_length = 64 - processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer) - return processor - - -# here we list all keys to be renamed (original name on the left, our name on the right) -def rename_keys(original_param_names): - # EfficientNet image encoder - block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")] - block_names = list(set(block_names)) - block_names = sorted(block_names) - num_blocks = len(block_names) - block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))} - - rename_keys = [] - rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight")) - rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight")) - rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias")) - rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean")) - rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var")) - - for b in block_names: - hf_b = block_name_mapping[b] - rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight")) - rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight")) - rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias")) - rename_keys.append( - (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean") - ) - rename_keys.append( - (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var") - ) - rename_keys.append( - (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight") - ) - rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight")) - rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias")) - rename_keys.append( - (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean") - ) - rename_keys.append( - (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var") - ) - - rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight")) - rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias")) - rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight")) - rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias")) - rename_keys.append( - (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight") - ) - rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight")) - rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias")) - rename_keys.append( - (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean") - ) - rename_keys.append( - (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var") - ) - - key_mapping = {} - for item in rename_keys: - if item[0] in original_param_names: - key_mapping[item[0]] = "vision_model." + item[1] - - # BERT text encoder - rename_keys = [] - old = "tf_bert_model/bert" - new = "text_model" - for i in range(12): - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0", - f"{new}.encoder.layer.{i}.attention.self.query.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/query/bias:0", - f"{new}.encoder.layer.{i}.attention.self.query.bias", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0", - f"{new}.encoder.layer.{i}.attention.self.key.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/key/bias:0", - f"{new}.encoder.layer.{i}.attention.self.key.bias", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0", - f"{new}.encoder.layer.{i}.attention.self.value.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/self/value/bias:0", - f"{new}.encoder.layer.{i}.attention.self.value.bias", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0", - f"{new}.encoder.layer.{i}.attention.output.dense.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0", - f"{new}.encoder.layer.{i}.attention.output.dense.bias", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0", - f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0", - f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0", - f"{new}.encoder.layer.{i}.intermediate.dense.weight", - ) - ) - rename_keys.append( - ( - f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0", - f"{new}.encoder.layer.{i}.intermediate.dense.bias", - ) - ) - rename_keys.append( - (f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight") - ) - rename_keys.append( - (f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias") - ) - rename_keys.append( - (f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight") - ) - rename_keys.append( - (f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias") - ) - - rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight")) - rename_keys.append( - (f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight") - ) - rename_keys.append( - (f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight") - ) - rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight")) - rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias")) - - rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight")) - rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias")) - rename_keys.append(("dense/kernel:0", "text_projection.weight")) - rename_keys.append(("dense/bias:0", "text_projection.bias")) - rename_keys.append(("dense/bias:0", "text_projection.bias")) - rename_keys.append(("temperature:0", "temperature")) - - for item in rename_keys: - if item[0] in original_param_names: - key_mapping[item[0]] = item[1] - return key_mapping - - -def replace_params(hf_params, tf_params, key_mapping): - list(hf_params.keys()) - - for key, value in tf_params.items(): - if key not in key_mapping: - continue - - hf_key = key_mapping[key] - if "_conv" in key and "kernel" in key: - new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1) - elif "embeddings" in key: - new_hf_value = torch.from_numpy(value) - elif "depthwise_kernel" in key: - new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1) - elif "kernel" in key: - new_hf_value = torch.from_numpy(np.transpose(value)) - elif "temperature" in key: - new_hf_value = value - elif "bn/gamma" in key or "bn/beta" in key: - new_hf_value = torch.from_numpy(np.transpose(value)).squeeze() - else: - new_hf_value = torch.from_numpy(value) - - # Replace HF parameters with original TF model parameters - hf_params[hf_key].copy_(new_hf_value) - - -@torch.no_grad() -def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub): - """ - Copy/paste/tweak model's weights to our ALIGN structure. - """ - # Load original model - seq_length = 64 - tok = Tokenizer(seq_length) - original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size()) - original_model.compile() - original_model.load_weights(checkpoint_path) - - tf_params = original_model.trainable_variables - tf_non_train_params = original_model.non_trainable_variables - tf_params = {param.name: param.numpy() for param in tf_params} - for param in tf_non_train_params: - tf_params[param.name] = param.numpy() - tf_param_names = list(tf_params.keys()) - - # Load HuggingFace model - config = get_align_config() - hf_model = AlignModel(config).eval() - hf_params = hf_model.state_dict() - - # Create src-to-dst parameter name mapping dictionary - print("Converting parameters...") - key_mapping = rename_keys(tf_param_names) - replace_params(hf_params, tf_params, key_mapping) - - # Initialize processor - processor = get_processor() - inputs = processor( - images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt" - ) - - # HF model inference - hf_model.eval() - with torch.no_grad(): - outputs = hf_model(**inputs) - - hf_image_features = outputs.image_embeds.detach().numpy() - hf_text_features = outputs.text_embeds.detach().numpy() - - # Original model inference - original_model.trainable = False - tf_image_processor = EfficientNetImageProcessor( - do_center_crop=True, - do_rescale=False, - do_normalize=False, - include_top=False, - resample=Image.BILINEAR, - ) - image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"] - text = tok(tf.constant(["A picture of a cat"])) - - image_features = original_model.image_encoder(image, training=False) - text_features = original_model.text_encoder(text, training=False) - - image_features = tf.nn.l2_normalize(image_features, axis=-1) - text_features = tf.nn.l2_normalize(text_features, axis=-1) - - # Check whether original and HF model outputs match -> np.allclose - if not np.allclose(image_features, hf_image_features, atol=1e-3): - raise ValueError("The predicted image features are not the same.") - if not np.allclose(text_features, hf_text_features, atol=1e-3): - raise ValueError("The predicted text features are not the same.") - print("Model outputs match!") - - if save_model: - # Create folder to save model - if not os.path.isdir(pytorch_dump_folder_path): - os.mkdir(pytorch_dump_folder_path) - # Save converted model and image processor - hf_model.save_pretrained(pytorch_dump_folder_path) - processor.save_pretrained(pytorch_dump_folder_path) - - if push_to_hub: - # Push model and image processor to hub - print("Pushing converted ALIGN to the hub...") - processor.push_to_hub("align-base") - hf_model.push_to_hub("align-base") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--checkpoint_path", - default="./weights/model-weights", - type=str, - help="Path to the pretrained TF ALIGN checkpoint.", - ) - parser.add_argument( - "--pytorch_dump_folder_path", - default="hf_model", - type=str, - help="Path to the output PyTorch model directory.", - ) - parser.add_argument("--save_model", action="store_true", help="Save model to local") - parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub") - - args = parser.parse_args() - convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py deleted file mode 100644 index 0588d03cb6cd..000000000000 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ /dev/null @@ -1,413 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Auto Model class.""" - -from collections import OrderedDict - -from ...utils import logging -from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update -from .configuration_auto import CONFIG_MAPPING_NAMES - - -logger = logging.get_logger(__name__) - - -FLAX_MODEL_MAPPING_NAMES = OrderedDict( - [ - # Base model mapping - ("albert", "FlaxAlbertModel"), - ("bart", "FlaxBartModel"), - ("beit", "FlaxBeitModel"), - ("bert", "FlaxBertModel"), - ("big_bird", "FlaxBigBirdModel"), - ("blenderbot", "FlaxBlenderbotModel"), - ("blenderbot-small", "FlaxBlenderbotSmallModel"), - ("bloom", "FlaxBloomModel"), - ("clip", "FlaxCLIPModel"), - ("dinov2", "FlaxDinov2Model"), - ("distilbert", "FlaxDistilBertModel"), - ("electra", "FlaxElectraModel"), - ("gemma", "FlaxGemmaModel"), - ("gpt-sw3", "FlaxGPT2Model"), - ("gpt2", "FlaxGPT2Model"), - ("gpt_neo", "FlaxGPTNeoModel"), - ("gptj", "FlaxGPTJModel"), - ("llama", "FlaxLlamaModel"), - ("longt5", "FlaxLongT5Model"), - ("marian", "FlaxMarianModel"), - ("mbart", "FlaxMBartModel"), - ("mistral", "FlaxMistralModel"), - ("mt5", "FlaxMT5Model"), - ("opt", "FlaxOPTModel"), - ("pegasus", "FlaxPegasusModel"), - ("regnet", "FlaxRegNetModel"), - ("resnet", "FlaxResNetModel"), - ("roberta", "FlaxRobertaModel"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), - ("roformer", "FlaxRoFormerModel"), - ("t5", "FlaxT5Model"), - ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), - ("vit", "FlaxViTModel"), - ("wav2vec2", "FlaxWav2Vec2Model"), - ("whisper", "FlaxWhisperModel"), - ("xglm", "FlaxXGLMModel"), - ("xlm-roberta", "FlaxXLMRobertaModel"), - ] -) - -FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( - [ - # Model for pre-training mapping - ("albert", "FlaxAlbertForPreTraining"), - ("bart", "FlaxBartForConditionalGeneration"), - ("bert", "FlaxBertForPreTraining"), - ("big_bird", "FlaxBigBirdForPreTraining"), - ("electra", "FlaxElectraForPreTraining"), - ("longt5", "FlaxLongT5ForConditionalGeneration"), - ("mbart", "FlaxMBartForConditionalGeneration"), - ("mt5", "FlaxMT5ForConditionalGeneration"), - ("roberta", "FlaxRobertaForMaskedLM"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"), - ("roformer", "FlaxRoFormerForMaskedLM"), - ("t5", "FlaxT5ForConditionalGeneration"), - ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), - ("whisper", "FlaxWhisperForConditionalGeneration"), - ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), - ] -) - -FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Masked LM mapping - ("albert", "FlaxAlbertForMaskedLM"), - ("bart", "FlaxBartForConditionalGeneration"), - ("bert", "FlaxBertForMaskedLM"), - ("big_bird", "FlaxBigBirdForMaskedLM"), - ("distilbert", "FlaxDistilBertForMaskedLM"), - ("electra", "FlaxElectraForMaskedLM"), - ("mbart", "FlaxMBartForConditionalGeneration"), - ("roberta", "FlaxRobertaForMaskedLM"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"), - ("roformer", "FlaxRoFormerForMaskedLM"), - ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), - ] -) - -FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Seq2Seq Causal LM mapping - ("bart", "FlaxBartForConditionalGeneration"), - ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), - ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), - ("encoder-decoder", "FlaxEncoderDecoderModel"), - ("longt5", "FlaxLongT5ForConditionalGeneration"), - ("marian", "FlaxMarianMTModel"), - ("mbart", "FlaxMBartForConditionalGeneration"), - ("mt5", "FlaxMT5ForConditionalGeneration"), - ("pegasus", "FlaxPegasusForConditionalGeneration"), - ("t5", "FlaxT5ForConditionalGeneration"), - ] -) - -FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Image-classification - ("beit", "FlaxBeitForImageClassification"), - ("dinov2", "FlaxDinov2ForImageClassification"), - ("regnet", "FlaxRegNetForImageClassification"), - ("resnet", "FlaxResNetForImageClassification"), - ("vit", "FlaxViTForImageClassification"), - ] -) - -FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"), - ] -) - -FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Causal LM mapping - ("bart", "FlaxBartForCausalLM"), - ("bert", "FlaxBertForCausalLM"), - ("big_bird", "FlaxBigBirdForCausalLM"), - ("bloom", "FlaxBloomForCausalLM"), - ("electra", "FlaxElectraForCausalLM"), - ("gemma", "FlaxGemmaForCausalLM"), - ("gpt-sw3", "FlaxGPT2LMHeadModel"), - ("gpt2", "FlaxGPT2LMHeadModel"), - ("gpt_neo", "FlaxGPTNeoForCausalLM"), - ("gptj", "FlaxGPTJForCausalLM"), - ("llama", "FlaxLlamaForCausalLM"), - ("mistral", "FlaxMistralForCausalLM"), - ("opt", "FlaxOPTForCausalLM"), - ("roberta", "FlaxRobertaForCausalLM"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"), - ("xglm", "FlaxXGLMForCausalLM"), - ("xlm-roberta", "FlaxXLMRobertaForCausalLM"), - ] -) - -FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Sequence Classification mapping - ("albert", "FlaxAlbertForSequenceClassification"), - ("bart", "FlaxBartForSequenceClassification"), - ("bert", "FlaxBertForSequenceClassification"), - ("big_bird", "FlaxBigBirdForSequenceClassification"), - ("distilbert", "FlaxDistilBertForSequenceClassification"), - ("electra", "FlaxElectraForSequenceClassification"), - ("mbart", "FlaxMBartForSequenceClassification"), - ("roberta", "FlaxRobertaForSequenceClassification"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"), - ("roformer", "FlaxRoFormerForSequenceClassification"), - ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"), - ] -) - -FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - # Model for Question Answering mapping - ("albert", "FlaxAlbertForQuestionAnswering"), - ("bart", "FlaxBartForQuestionAnswering"), - ("bert", "FlaxBertForQuestionAnswering"), - ("big_bird", "FlaxBigBirdForQuestionAnswering"), - ("distilbert", "FlaxDistilBertForQuestionAnswering"), - ("electra", "FlaxElectraForQuestionAnswering"), - ("mbart", "FlaxMBartForQuestionAnswering"), - ("roberta", "FlaxRobertaForQuestionAnswering"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"), - ("roformer", "FlaxRoFormerForQuestionAnswering"), - ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"), - ] -) - -FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Token Classification mapping - ("albert", "FlaxAlbertForTokenClassification"), - ("bert", "FlaxBertForTokenClassification"), - ("big_bird", "FlaxBigBirdForTokenClassification"), - ("distilbert", "FlaxDistilBertForTokenClassification"), - ("electra", "FlaxElectraForTokenClassification"), - ("roberta", "FlaxRobertaForTokenClassification"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"), - ("roformer", "FlaxRoFormerForTokenClassification"), - ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"), - ] -) - -FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( - [ - # Model for Multiple Choice mapping - ("albert", "FlaxAlbertForMultipleChoice"), - ("bert", "FlaxBertForMultipleChoice"), - ("big_bird", "FlaxBigBirdForMultipleChoice"), - ("distilbert", "FlaxDistilBertForMultipleChoice"), - ("electra", "FlaxElectraForMultipleChoice"), - ("roberta", "FlaxRobertaForMultipleChoice"), - ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"), - ("roformer", "FlaxRoFormerForMultipleChoice"), - ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"), - ] -) - -FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( - [ - ("bert", "FlaxBertForNextSentencePrediction"), - ] -) - -FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"), - ("whisper", "FlaxWhisperForConditionalGeneration"), - ] -) - -FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - ("whisper", "FlaxWhisperForAudioClassification"), - ] -) - -FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) -FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) -FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES) -FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES -) -FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES -) -FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) -FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) -FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES -) -FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES -) -FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES -) -FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES -) -FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES -) -FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES -) -FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES -) - - -class FlaxAutoModel(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_MAPPING - - -FlaxAutoModel = auto_class_update(FlaxAutoModel) - - -class FlaxAutoModelForPreTraining(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING - - -FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining") - - -class FlaxAutoModelForCausalLM(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING - - -FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling") - - -class FlaxAutoModelForMaskedLM(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING - - -FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling") - - -class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - - -FlaxAutoModelForSeq2SeqLM = auto_class_update( - FlaxAutoModelForSeq2SeqLM, - head_doc="sequence-to-sequence language modeling", - checkpoint_for_example="google-t5/t5-base", -) - - -class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING - - -FlaxAutoModelForSequenceClassification = auto_class_update( - FlaxAutoModelForSequenceClassification, head_doc="sequence classification" -) - - -class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING - - -FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering") - - -class FlaxAutoModelForTokenClassification(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING - - -FlaxAutoModelForTokenClassification = auto_class_update( - FlaxAutoModelForTokenClassification, head_doc="token classification" -) - - -class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING - - -FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice") - - -class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING - - -FlaxAutoModelForNextSentencePrediction = auto_class_update( - FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction" -) - - -class FlaxAutoModelForImageClassification(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING - - -FlaxAutoModelForImageClassification = auto_class_update( - FlaxAutoModelForImageClassification, head_doc="image classification" -) - - -class FlaxAutoModelForVision2Seq(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING - - -FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling") - - -class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): - _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING - - -FlaxAutoModelForSpeechSeq2Seq = auto_class_update( - FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" -) - -__all__ = [ - "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", - "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", - "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", - "FLAX_MODEL_FOR_MASKED_LM_MAPPING", - "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", - "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", - "FLAX_MODEL_FOR_PRETRAINING_MAPPING", - "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", - "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", - "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", - "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", - "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", - "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", - "FLAX_MODEL_MAPPING", - "FlaxAutoModel", - "FlaxAutoModelForCausalLM", - "FlaxAutoModelForImageClassification", - "FlaxAutoModelForMaskedLM", - "FlaxAutoModelForMultipleChoice", - "FlaxAutoModelForNextSentencePrediction", - "FlaxAutoModelForPreTraining", - "FlaxAutoModelForQuestionAnswering", - "FlaxAutoModelForSeq2SeqLM", - "FlaxAutoModelForSequenceClassification", - "FlaxAutoModelForSpeechSeq2Seq", - "FlaxAutoModelForTokenClassification", - "FlaxAutoModelForVision2Seq", -] diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py deleted file mode 100644 index cf39f4d7c9c4..000000000000 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ /dev/null @@ -1,776 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Auto Model class.""" - -import warnings -from collections import OrderedDict - -from ...utils import logging -from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update -from .configuration_auto import CONFIG_MAPPING_NAMES - - -logger = logging.get_logger(__name__) - - -TF_MODEL_MAPPING_NAMES = OrderedDict( - [ - # Base model mapping - ("albert", "TFAlbertModel"), - ("bart", "TFBartModel"), - ("bert", "TFBertModel"), - ("blenderbot", "TFBlenderbotModel"), - ("blenderbot-small", "TFBlenderbotSmallModel"), - ("blip", "TFBlipModel"), - ("camembert", "TFCamembertModel"), - ("clip", "TFCLIPModel"), - ("convbert", "TFConvBertModel"), - ("convnext", "TFConvNextModel"), - ("convnextv2", "TFConvNextV2Model"), - ("ctrl", "TFCTRLModel"), - ("cvt", "TFCvtModel"), - ("data2vec-vision", "TFData2VecVisionModel"), - ("deberta", "TFDebertaModel"), - ("deberta-v2", "TFDebertaV2Model"), - ("deit", "TFDeiTModel"), - ("distilbert", "TFDistilBertModel"), - ("dpr", "TFDPRQuestionEncoder"), - ("efficientformer", "TFEfficientFormerModel"), - ("electra", "TFElectraModel"), - ("esm", "TFEsmModel"), - ("flaubert", "TFFlaubertModel"), - ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), - ("gpt-sw3", "TFGPT2Model"), - ("gpt2", "TFGPT2Model"), - ("gptj", "TFGPTJModel"), - ("groupvit", "TFGroupViTModel"), - ("hubert", "TFHubertModel"), - ("idefics", "TFIdeficsModel"), - ("layoutlm", "TFLayoutLMModel"), - ("layoutlmv3", "TFLayoutLMv3Model"), - ("led", "TFLEDModel"), - ("longformer", "TFLongformerModel"), - ("lxmert", "TFLxmertModel"), - ("marian", "TFMarianModel"), - ("mbart", "TFMBartModel"), - ("mistral", "TFMistralModel"), - ("mobilebert", "TFMobileBertModel"), - ("mobilevit", "TFMobileViTModel"), - ("mpnet", "TFMPNetModel"), - ("mt5", "TFMT5Model"), - ("openai-gpt", "TFOpenAIGPTModel"), - ("opt", "TFOPTModel"), - ("pegasus", "TFPegasusModel"), - ("regnet", "TFRegNetModel"), - ("rembert", "TFRemBertModel"), - ("resnet", "TFResNetModel"), - ("roberta", "TFRobertaModel"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), - ("roformer", "TFRoFormerModel"), - ("sam", "TFSamModel"), - ("sam_vision_model", "TFSamVisionModel"), - ("segformer", "TFSegformerModel"), - ("speech_to_text", "TFSpeech2TextModel"), - ("swiftformer", "TFSwiftFormerModel"), - ("swin", "TFSwinModel"), - ("t5", "TFT5Model"), - ("tapas", "TFTapasModel"), - ("transfo-xl", "TFTransfoXLModel"), - ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"), - ("vit", "TFViTModel"), - ("vit_mae", "TFViTMAEModel"), - ("wav2vec2", "TFWav2Vec2Model"), - ("whisper", "TFWhisperModel"), - ("xglm", "TFXGLMModel"), - ("xlm", "TFXLMModel"), - ("xlm-roberta", "TFXLMRobertaModel"), - ("xlnet", "TFXLNetModel"), - ] -) - -TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( - [ - # Model for pre-training mapping - ("albert", "TFAlbertForPreTraining"), - ("bart", "TFBartForConditionalGeneration"), - ("bert", "TFBertForPreTraining"), - ("camembert", "TFCamembertForMaskedLM"), - ("ctrl", "TFCTRLLMHeadModel"), - ("distilbert", "TFDistilBertForMaskedLM"), - ("electra", "TFElectraForPreTraining"), - ("flaubert", "TFFlaubertWithLMHeadModel"), - ("funnel", "TFFunnelForPreTraining"), - ("gpt-sw3", "TFGPT2LMHeadModel"), - ("gpt2", "TFGPT2LMHeadModel"), - ("idefics", "TFIdeficsForVisionText2Text"), - ("layoutlm", "TFLayoutLMForMaskedLM"), - ("lxmert", "TFLxmertForPreTraining"), - ("mobilebert", "TFMobileBertForPreTraining"), - ("mpnet", "TFMPNetForMaskedLM"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), - ("roberta", "TFRobertaForMaskedLM"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), - ("t5", "TFT5ForConditionalGeneration"), - ("tapas", "TFTapasForMaskedLM"), - ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("vit_mae", "TFViTMAEForPreTraining"), - ("xlm", "TFXLMWithLMHeadModel"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ("xlnet", "TFXLNetLMHeadModel"), - ] -) - -TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( - [ - # Model with LM heads mapping - ("albert", "TFAlbertForMaskedLM"), - ("bart", "TFBartForConditionalGeneration"), - ("bert", "TFBertForMaskedLM"), - ("camembert", "TFCamembertForMaskedLM"), - ("convbert", "TFConvBertForMaskedLM"), - ("ctrl", "TFCTRLLMHeadModel"), - ("distilbert", "TFDistilBertForMaskedLM"), - ("electra", "TFElectraForMaskedLM"), - ("esm", "TFEsmForMaskedLM"), - ("flaubert", "TFFlaubertWithLMHeadModel"), - ("funnel", "TFFunnelForMaskedLM"), - ("gpt-sw3", "TFGPT2LMHeadModel"), - ("gpt2", "TFGPT2LMHeadModel"), - ("gptj", "TFGPTJForCausalLM"), - ("layoutlm", "TFLayoutLMForMaskedLM"), - ("led", "TFLEDForConditionalGeneration"), - ("longformer", "TFLongformerForMaskedLM"), - ("marian", "TFMarianMTModel"), - ("mobilebert", "TFMobileBertForMaskedLM"), - ("mpnet", "TFMPNetForMaskedLM"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), - ("rembert", "TFRemBertForMaskedLM"), - ("roberta", "TFRobertaForMaskedLM"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), - ("roformer", "TFRoFormerForMaskedLM"), - ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), - ("t5", "TFT5ForConditionalGeneration"), - ("tapas", "TFTapasForMaskedLM"), - ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("whisper", "TFWhisperForConditionalGeneration"), - ("xlm", "TFXLMWithLMHeadModel"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ("xlnet", "TFXLNetLMHeadModel"), - ] -) - -TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Causal LM mapping - ("bert", "TFBertLMHeadModel"), - ("camembert", "TFCamembertForCausalLM"), - ("ctrl", "TFCTRLLMHeadModel"), - ("gpt-sw3", "TFGPT2LMHeadModel"), - ("gpt2", "TFGPT2LMHeadModel"), - ("gptj", "TFGPTJForCausalLM"), - ("mistral", "TFMistralForCausalLM"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), - ("opt", "TFOPTForCausalLM"), - ("rembert", "TFRemBertForCausalLM"), - ("roberta", "TFRobertaForCausalLM"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"), - ("roformer", "TFRoFormerForCausalLM"), - ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("xglm", "TFXGLMForCausalLM"), - ("xlm", "TFXLMWithLMHeadModel"), - ("xlm-roberta", "TFXLMRobertaForCausalLM"), - ("xlnet", "TFXLNetLMHeadModel"), - ] -) - -TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( - [ - ("deit", "TFDeiTForMaskedImageModeling"), - ("swin", "TFSwinForMaskedImageModeling"), - ] -) - -TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Image-classsification - ("convnext", "TFConvNextForImageClassification"), - ("convnextv2", "TFConvNextV2ForImageClassification"), - ("cvt", "TFCvtForImageClassification"), - ("data2vec-vision", "TFData2VecVisionForImageClassification"), - ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")), - ( - "efficientformer", - ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"), - ), - ("mobilevit", "TFMobileViTForImageClassification"), - ("regnet", "TFRegNetForImageClassification"), - ("resnet", "TFResNetForImageClassification"), - ("segformer", "TFSegformerForImageClassification"), - ("swiftformer", "TFSwiftFormerForImageClassification"), - ("swin", "TFSwinForImageClassification"), - ("vit", "TFViTForImageClassification"), - ] -) - - -TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Zero Shot Image Classification mapping - ("blip", "TFBlipModel"), - ("clip", "TFCLIPModel"), - ] -) - - -TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Semantic Segmentation mapping - ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"), - ("mobilevit", "TFMobileViTForSemanticSegmentation"), - ("segformer", "TFSegformerForSemanticSegmentation"), - ] -) - -TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("blip", "TFBlipForConditionalGeneration"), - ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"), - ] -) - -TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Masked LM mapping - ("albert", "TFAlbertForMaskedLM"), - ("bert", "TFBertForMaskedLM"), - ("camembert", "TFCamembertForMaskedLM"), - ("convbert", "TFConvBertForMaskedLM"), - ("deberta", "TFDebertaForMaskedLM"), - ("deberta-v2", "TFDebertaV2ForMaskedLM"), - ("distilbert", "TFDistilBertForMaskedLM"), - ("electra", "TFElectraForMaskedLM"), - ("esm", "TFEsmForMaskedLM"), - ("flaubert", "TFFlaubertWithLMHeadModel"), - ("funnel", "TFFunnelForMaskedLM"), - ("layoutlm", "TFLayoutLMForMaskedLM"), - ("longformer", "TFLongformerForMaskedLM"), - ("mobilebert", "TFMobileBertForMaskedLM"), - ("mpnet", "TFMPNetForMaskedLM"), - ("rembert", "TFRemBertForMaskedLM"), - ("roberta", "TFRobertaForMaskedLM"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), - ("roformer", "TFRoFormerForMaskedLM"), - ("tapas", "TFTapasForMaskedLM"), - ("xlm", "TFXLMWithLMHeadModel"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ] -) - -TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - # Model for Seq2Seq Causal LM mapping - ("bart", "TFBartForConditionalGeneration"), - ("blenderbot", "TFBlenderbotForConditionalGeneration"), - ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"), - ("encoder-decoder", "TFEncoderDecoderModel"), - ("led", "TFLEDForConditionalGeneration"), - ("marian", "TFMarianMTModel"), - ("mbart", "TFMBartForConditionalGeneration"), - ("mt5", "TFMT5ForConditionalGeneration"), - ("pegasus", "TFPegasusForConditionalGeneration"), - ("t5", "TFT5ForConditionalGeneration"), - ] -) - -TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), - ("whisper", "TFWhisperForConditionalGeneration"), - ] -) - -TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Sequence Classification mapping - ("albert", "TFAlbertForSequenceClassification"), - ("bart", "TFBartForSequenceClassification"), - ("bert", "TFBertForSequenceClassification"), - ("camembert", "TFCamembertForSequenceClassification"), - ("convbert", "TFConvBertForSequenceClassification"), - ("ctrl", "TFCTRLForSequenceClassification"), - ("deberta", "TFDebertaForSequenceClassification"), - ("deberta-v2", "TFDebertaV2ForSequenceClassification"), - ("distilbert", "TFDistilBertForSequenceClassification"), - ("electra", "TFElectraForSequenceClassification"), - ("esm", "TFEsmForSequenceClassification"), - ("flaubert", "TFFlaubertForSequenceClassification"), - ("funnel", "TFFunnelForSequenceClassification"), - ("gpt-sw3", "TFGPT2ForSequenceClassification"), - ("gpt2", "TFGPT2ForSequenceClassification"), - ("gptj", "TFGPTJForSequenceClassification"), - ("layoutlm", "TFLayoutLMForSequenceClassification"), - ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"), - ("longformer", "TFLongformerForSequenceClassification"), - ("mistral", "TFMistralForSequenceClassification"), - ("mobilebert", "TFMobileBertForSequenceClassification"), - ("mpnet", "TFMPNetForSequenceClassification"), - ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), - ("rembert", "TFRemBertForSequenceClassification"), - ("roberta", "TFRobertaForSequenceClassification"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"), - ("roformer", "TFRoFormerForSequenceClassification"), - ("tapas", "TFTapasForSequenceClassification"), - ("transfo-xl", "TFTransfoXLForSequenceClassification"), - ("xlm", "TFXLMForSequenceClassification"), - ("xlm-roberta", "TFXLMRobertaForSequenceClassification"), - ("xlnet", "TFXLNetForSequenceClassification"), - ] -) - -TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - # Model for Question Answering mapping - ("albert", "TFAlbertForQuestionAnswering"), - ("bert", "TFBertForQuestionAnswering"), - ("camembert", "TFCamembertForQuestionAnswering"), - ("convbert", "TFConvBertForQuestionAnswering"), - ("deberta", "TFDebertaForQuestionAnswering"), - ("deberta-v2", "TFDebertaV2ForQuestionAnswering"), - ("distilbert", "TFDistilBertForQuestionAnswering"), - ("electra", "TFElectraForQuestionAnswering"), - ("flaubert", "TFFlaubertForQuestionAnsweringSimple"), - ("funnel", "TFFunnelForQuestionAnswering"), - ("gptj", "TFGPTJForQuestionAnswering"), - ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"), - ("longformer", "TFLongformerForQuestionAnswering"), - ("mobilebert", "TFMobileBertForQuestionAnswering"), - ("mpnet", "TFMPNetForQuestionAnswering"), - ("rembert", "TFRemBertForQuestionAnswering"), - ("roberta", "TFRobertaForQuestionAnswering"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"), - ("roformer", "TFRoFormerForQuestionAnswering"), - ("xlm", "TFXLMForQuestionAnsweringSimple"), - ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"), - ("xlnet", "TFXLNetForQuestionAnsweringSimple"), - ] -) -TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")]) - -TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - ("layoutlm", "TFLayoutLMForQuestionAnswering"), - ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"), - ] -) - - -TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - # Model for Table Question Answering mapping - ("tapas", "TFTapasForQuestionAnswering"), - ] -) - -TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Token Classification mapping - ("albert", "TFAlbertForTokenClassification"), - ("bert", "TFBertForTokenClassification"), - ("camembert", "TFCamembertForTokenClassification"), - ("convbert", "TFConvBertForTokenClassification"), - ("deberta", "TFDebertaForTokenClassification"), - ("deberta-v2", "TFDebertaV2ForTokenClassification"), - ("distilbert", "TFDistilBertForTokenClassification"), - ("electra", "TFElectraForTokenClassification"), - ("esm", "TFEsmForTokenClassification"), - ("flaubert", "TFFlaubertForTokenClassification"), - ("funnel", "TFFunnelForTokenClassification"), - ("layoutlm", "TFLayoutLMForTokenClassification"), - ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"), - ("longformer", "TFLongformerForTokenClassification"), - ("mobilebert", "TFMobileBertForTokenClassification"), - ("mpnet", "TFMPNetForTokenClassification"), - ("rembert", "TFRemBertForTokenClassification"), - ("roberta", "TFRobertaForTokenClassification"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"), - ("roformer", "TFRoFormerForTokenClassification"), - ("xlm", "TFXLMForTokenClassification"), - ("xlm-roberta", "TFXLMRobertaForTokenClassification"), - ("xlnet", "TFXLNetForTokenClassification"), - ] -) - -TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( - [ - # Model for Multiple Choice mapping - ("albert", "TFAlbertForMultipleChoice"), - ("bert", "TFBertForMultipleChoice"), - ("camembert", "TFCamembertForMultipleChoice"), - ("convbert", "TFConvBertForMultipleChoice"), - ("deberta-v2", "TFDebertaV2ForMultipleChoice"), - ("distilbert", "TFDistilBertForMultipleChoice"), - ("electra", "TFElectraForMultipleChoice"), - ("flaubert", "TFFlaubertForMultipleChoice"), - ("funnel", "TFFunnelForMultipleChoice"), - ("longformer", "TFLongformerForMultipleChoice"), - ("mobilebert", "TFMobileBertForMultipleChoice"), - ("mpnet", "TFMPNetForMultipleChoice"), - ("rembert", "TFRemBertForMultipleChoice"), - ("roberta", "TFRobertaForMultipleChoice"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"), - ("roformer", "TFRoFormerForMultipleChoice"), - ("xlm", "TFXLMForMultipleChoice"), - ("xlm-roberta", "TFXLMRobertaForMultipleChoice"), - ("xlnet", "TFXLNetForMultipleChoice"), - ] -) - -TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( - [ - ("bert", "TFBertForNextSentencePrediction"), - ("mobilebert", "TFMobileBertForNextSentencePrediction"), - ] -) -TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( - [ - ("sam", "TFSamModel"), - ] -) -TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( - [ - ("albert", "TFAlbertModel"), - ("bert", "TFBertModel"), - ("convbert", "TFConvBertModel"), - ("deberta", "TFDebertaModel"), - ("deberta-v2", "TFDebertaV2Model"), - ("distilbert", "TFDistilBertModel"), - ("electra", "TFElectraModel"), - ("flaubert", "TFFlaubertModel"), - ("longformer", "TFLongformerModel"), - ("mobilebert", "TFMobileBertModel"), - ("mt5", "TFMT5EncoderModel"), - ("rembert", "TFRemBertModel"), - ("roberta", "TFRobertaModel"), - ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), - ("roformer", "TFRoFormerModel"), - ("t5", "TFT5EncoderModel"), - ("xlm", "TFXLMModel"), - ("xlm-roberta", "TFXLMRobertaModel"), - ] -) - -TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) -TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) -TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES) -TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) -TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES -) -TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES -) -TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES -) -TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES -) -TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) -TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES) -TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES -) -TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES -) -TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES -) -TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES -) -TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES -) -TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES -) -TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES -) -TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES -) -TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES -) -TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES -) - -TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES -) - -TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) - - -class TFAutoModelForMaskGeneration(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING - - -class TFAutoModelForTextEncoding(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING - - -class TFAutoModel(_BaseAutoModelClass): - _model_mapping = TF_MODEL_MAPPING - - -TFAutoModel = auto_class_update(TFAutoModel) - - -class TFAutoModelForAudioClassification(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING - - -TFAutoModelForAudioClassification = auto_class_update( - TFAutoModelForAudioClassification, head_doc="audio classification" -) - - -class TFAutoModelForPreTraining(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING - - -TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining") - - -# Private on purpose, the public class will add the deprecation warnings. -class _TFAutoModelWithLMHead(_BaseAutoModelClass): - _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING - - -_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling") - - -class TFAutoModelForCausalLM(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING - - -TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling") - - -class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING - - -TFAutoModelForMaskedImageModeling = auto_class_update( - TFAutoModelForMaskedImageModeling, head_doc="masked image modeling" -) - - -class TFAutoModelForImageClassification(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING - - -TFAutoModelForImageClassification = auto_class_update( - TFAutoModelForImageClassification, head_doc="image classification" -) - - -class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING - - -TFAutoModelForZeroShotImageClassification = auto_class_update( - TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" -) - - -class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING - - -TFAutoModelForSemanticSegmentation = auto_class_update( - TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation" -) - - -class TFAutoModelForVision2Seq(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING - - -TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling") - - -class TFAutoModelForMaskedLM(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING - - -TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling") - - -class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - - -TFAutoModelForSeq2SeqLM = auto_class_update( - TFAutoModelForSeq2SeqLM, - head_doc="sequence-to-sequence language modeling", - checkpoint_for_example="google-t5/t5-base", -) - - -class TFAutoModelForSequenceClassification(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING - - -TFAutoModelForSequenceClassification = auto_class_update( - TFAutoModelForSequenceClassification, head_doc="sequence classification" -) - - -class TFAutoModelForQuestionAnswering(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING - - -TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering") - - -class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING - - -TFAutoModelForDocumentQuestionAnswering = auto_class_update( - TFAutoModelForDocumentQuestionAnswering, - head_doc="document question answering", - checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', -) - - -class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING - - -TFAutoModelForTableQuestionAnswering = auto_class_update( - TFAutoModelForTableQuestionAnswering, - head_doc="table question answering", - checkpoint_for_example="google/tapas-base-finetuned-wtq", -) - - -class TFAutoModelForTokenClassification(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING - - -TFAutoModelForTokenClassification = auto_class_update( - TFAutoModelForTokenClassification, head_doc="token classification" -) - - -class TFAutoModelForMultipleChoice(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING - - -TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice") - - -class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING - - -TFAutoModelForNextSentencePrediction = auto_class_update( - TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction" -) - - -class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): - _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING - - -TFAutoModelForSpeechSeq2Seq = auto_class_update( - TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" -) - - -class TFAutoModelWithLMHead(_TFAutoModelWithLMHead): - @classmethod - def from_config(cls, config): - warnings.warn( - "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use" - " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models" - " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.", - FutureWarning, - ) - return super().from_config(config) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - warnings.warn( - "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use" - " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models" - " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.", - FutureWarning, - ) - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - -__all__ = [ - "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", - "TF_MODEL_FOR_CAUSAL_LM_MAPPING", - "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", - "TF_MODEL_FOR_MASK_GENERATION_MAPPING", - "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", - "TF_MODEL_FOR_MASKED_LM_MAPPING", - "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", - "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", - "TF_MODEL_FOR_PRETRAINING_MAPPING", - "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", - "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", - "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", - "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", - "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", - "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", - "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", - "TF_MODEL_FOR_TEXT_ENCODING_MAPPING", - "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", - "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", - "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", - "TF_MODEL_MAPPING", - "TF_MODEL_WITH_LM_HEAD_MAPPING", - "TFAutoModel", - "TFAutoModelForAudioClassification", - "TFAutoModelForCausalLM", - "TFAutoModelForImageClassification", - "TFAutoModelForMaskedImageModeling", - "TFAutoModelForMaskedLM", - "TFAutoModelForMaskGeneration", - "TFAutoModelForMultipleChoice", - "TFAutoModelForNextSentencePrediction", - "TFAutoModelForPreTraining", - "TFAutoModelForDocumentQuestionAnswering", - "TFAutoModelForQuestionAnswering", - "TFAutoModelForSemanticSegmentation", - "TFAutoModelForSeq2SeqLM", - "TFAutoModelForSequenceClassification", - "TFAutoModelForSpeechSeq2Seq", - "TFAutoModelForTableQuestionAnswering", - "TFAutoModelForTextEncoding", - "TFAutoModelForTokenClassification", - "TFAutoModelForVision2Seq", - "TFAutoModelForZeroShotImageClassification", - "TFAutoModelWithLMHead", -] diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py deleted file mode 100644 index 818254f3bfa1..000000000000 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ /dev/null @@ -1,2006 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Bart model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, - FlaxSeq2SeqQuestionAnsweringModelOutput, - FlaxSeq2SeqSequenceClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_bart import BartConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/bart-base" -_CONFIG_FOR_DOC = "BartConfig" - - -BART_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BartConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -BART_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -BART_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -class FlaxBartAttention(nn.Module): - config: BartConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class FlaxBartEncoderLayer(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class FlaxBartEncoderLayerCollection(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxBartDecoderLayer(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -class FlaxBartDecoderLayerCollection(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxBartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - config: BartConfig - inner_dim: int - num_classes: int - pooler_dropout: float - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.dropout = nn.Dropout(rate=self.pooler_dropout) - self.out_proj = nn.Dense( - self.num_classes, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = jnp.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class FlaxBartEncoder(nn.Module): - config: BartConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(position_ids + self.offset) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxBartDecoder(nn.Module): - config: BartConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = self.embed_positions(position_ids + self.offset) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxBartModule(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxBartPreTrainedModel(FlaxPreTrainedModel): - config_class = BartConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: BartConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxBartForSequenceClassificationModule - input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration - - >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration - - >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.", - BART_START_DOCSTRING, -) -class FlaxBartModel(FlaxBartPreTrainedModel): - config: BartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxBartModule - - -append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -class FlaxBartForConditionalGenerationModule(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxBartModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING -) -class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): - module_class = FlaxBartForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration - - >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias.astype(self.dtype) - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration - - >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"]).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` - - Mask filling example: - - ```python - >>> import jax - >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration - - >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") - - >>> TXT = "My friends are but they eat too many carbs." - >>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"] - - >>> logits = model(input_ids).logits - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item() - >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = jax.lax.top_k(probs, k=1) - - >>> tokenizer.decode(predictions).split() - ``` -""" - -overwrite_call_docstring( - FlaxBartForConditionalGeneration, BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxBartForSequenceClassificationModule(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - num_labels: Optional[int] = None - - def setup(self): - self.model = FlaxBartModule(config=self.config, dtype=self.dtype) - self.classification_head = FlaxBartClassificationHead( - config=self.config, - inner_dim=self.config.d_model, - num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels, - pooler_dropout=self.config.classifier_dropout, - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] # last hidden state - - eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) - - # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation - if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer): - if len(jnp.unique(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - - if any(eos_mask.sum(1) == 0): - raise ValueError("There are missing tokens in input_ids") - - # Ensure to keep 1 only for the last token for each example - eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6 - eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0) - - sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1) - logits = self.classification_head(sentence_representation, deterministic=deterministic) - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return FlaxSeq2SeqSequenceClassifierOutput( - logits=logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - BART_START_DOCSTRING, -) -class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel): - module_class = FlaxBartForSequenceClassificationModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxBartForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBartForQuestionAnsweringModule(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - num_labels = 2 - - def setup(self): - self.model = FlaxBartModule(config=self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return output - - return FlaxSeq2SeqQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - BART_START_DOCSTRING, -) -class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel): - module_class = FlaxBartForQuestionAnsweringModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxBartForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel): - config_class = BartConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: BartConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - config.is_decoder = True - config.is_encoder_decoder = False - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - return module_init_outputs["params"] - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if encoder_hidden_states is not None and encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # prepare decoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxBartDecoderWrapper(nn.Module): - """ - This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is - used in combination with the [`EncoderDecoderModel`] framework. - """ - - config: BartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.d_model - embed_tokens = nn.Embed( - self.config.vocab_size, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype) - - def __call__(self, *args, **kwargs): - return self.decoder(*args, **kwargs) - - -class FlaxBartForCausalLMModule(nn.Module): - config: BartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings) - e.g for autoregressive tasks. - """, - BART_START_DOCSTRING, -) -class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel): - module_class = FlaxBartForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxBartForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxBartDecoderPreTrainedModel", - "FlaxBartForCausalLM", - "FlaxBartForConditionalGeneration", - "FlaxBartForQuestionAnswering", - "FlaxBartForSequenceClassification", - "FlaxBartModel", - "FlaxBartPreTrainedModel", -] diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py deleted file mode 100644 index 0a6d2317d696..000000000000 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ /dev/null @@ -1,1713 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Bart model.""" - -from __future__ import annotations - -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, - TFSeq2SeqSequenceClassifierOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_bart import BartConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/bart-large" -_CONFIG_FOR_DOC = "BartConfig" - - -LARGE_NEGATIVE = -1e8 - - -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFBartLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) - - def call( - self, - input_shape: tf.TensorShape | None = None, - past_key_values_length: int = 0, - position_ids: tf.Tensor | None = None, - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(seq_len, delta=1, name="range") - position_ids += past_key_values_length - - offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32 - return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype)) - - -class TFBartAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFBartEncoderLayer(keras.layers.Layer): - def __init__(self, config: BartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBartAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None, - layer_head_mask: tf.Tensor | None, - training: bool | None = False, - ) -> tf.Tensor: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)` - """ - residual = hidden_states - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFBartDecoderLayer(keras.layers.Layer): - def __init__(self, config: BartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBartAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFBartAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFBartClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs): - super().__init__(name=name, **kwargs) - self.dense = keras.layers.Dense(inner_dim, name="dense") - self.dropout = keras.layers.Dropout(pooler_dropout) - self.out_proj = keras.layers.Dense(num_classes, name="out_proj") - self.input_dim = inner_dim - self.inner_dim = inner_dim - - def call(self, inputs): - hidden_states = self.dropout(inputs) - hidden_states = self.dense(hidden_states) - hidden_states = keras.activations.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.input_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.inner_dim]) - - -class TFBartPretrainedModel(TFPreTrainedModel): - config_class = BartConfig - base_model_prefix = "model" - - @property - def dummy_inputs(self): - dummy_inputs = super().dummy_inputs - # Dummy inputs should not contain the default val of 1 - # as this is the padding token and some assertions check it - dummy_inputs["input_ids"] = dummy_inputs["input_ids"] * 2 - if "decoder_input_ids" in dummy_inputs: - dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2 - return dummy_inputs - - def tf_to_pt_weight_rename(self, tf_weight): - if tf_weight == "model.shared.weight": - return tf_weight, "model.decoder.embed_tokens.weight" - else: - return (tf_weight,) - - -BART_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`BartConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -BART_GENERATION_EXAMPLE = r""" - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration - - >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="tf") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5) - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") - >>> TXT = "My friends are but they eat too many carbs." - - >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large") - >>> input_ids = tokenizer([TXT], return_tensors="tf")["input_ids"] - >>> logits = model(input_ids).logits - >>> probs = tf.nn.softmax(logits[0]) - >>> # probs[5] is associated with the mask token - ``` -""" - - -BART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFBartEncoder(keras.layers.Layer): - config_class = BartConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFBartEncoderLayer`]. - - Args: - config: BartConfig - """ - - def __init__(self, config: BartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFBartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.embed_dim = config.d_model - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.embed_dim]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBartDecoder(keras.layers.Layer): - config_class = BartConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBartDecoderLayer`] - - Args: - config: BartConfig - embed_tokens: output embedding - """ - - def __init__(self, config: BartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFBartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - - self.dropout = keras.layers.Dropout(config.dropout) - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.tTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBartMainLayer(keras.layers.Layer): - config_class = BartConfig - - def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix - - self.encoder = TFBartEncoder(config, self.shared, name="encoder") - self.decoder = TFBartDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]: - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare BART Model outputting raw hidden-states without any specific head on top.", - BART_START_DOCSTRING, -) -class TFBartModel(TFBartPretrainedModel): - _requires_load_weight_prefix = True - - def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The BART Model with a language modeling head. Can be used for summarization.", - BART_START_DOCSTRING, -) -class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_missing = [r"final_logits_bias"] - _requires_load_weight_prefix = True - - def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(BART_GENERATION_EXAMPLE) - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -@add_start_docstrings( - """ - Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - BART_START_DOCSTRING, -) -class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") - self.classification_head = TFBartClassificationHead( - config.d_model, config.num_labels, config.classifier_dropout, name="classification_head" - ) - - @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSeq2SeqSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - if input_ids is None and inputs_embeds is not None: - raise NotImplementedError( - f"Passing input embeddings is currently not supported for {self.__class__.__name__}" - ) - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = outputs[0] - eos_mask = tf.equal(input_ids, self.config.eos_token_id) - # out the rows with False where present. Then verify all the final - # entries are True - self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1)) - tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of tokens."]) - - masked = tf.reshape( - tf.boolean_mask(last_hidden_state, eos_mask), - (tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]), - ) - - sentence_representation = masked[:, -1, :] - logits = self.classification_head(sentence_representation) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSeq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - def serving_output(self, output): - logits = tf.convert_to_tensor(output.logits) - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqSequenceClassifierOutput( - logits=logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "classification_head", None) is not None: - with tf.name_scope(self.classification_head.name): - self.classification_head.build(None) - - -__all__ = ["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"] diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py deleted file mode 100644 index c80deace6b39..000000000000 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ /dev/null @@ -1,956 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Microsoft Research and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Callable, Optional - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPooling, - FlaxMaskedLMOutput, - FlaxSequenceClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward -from .configuration_beit import BeitConfig - - -@flax.struct.dataclass -class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling): - """ - Class for outputs of [`FlaxBeitModel`]. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): - Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if - *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token - will be returned. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - -BEIT_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BeitConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BEIT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`AutoImageProcessor.__call__`] for details. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def relative_position_index_init(window_size: tuple[int, int]) -> jnp.ndarray: - """ - get pair-wise relative position index for each token inside the window - """ - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - - coords_h = np.arange(window_size[0]) - coords_w = np.arange(window_size[1]) - coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww - coords_flatten = np.reshape(coords, (2, -1)) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - - relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) - relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = num_relative_distance - 1 - return jnp.array(relative_position_index) - - -def ones_with_scale(key, shape, scale, dtype=jnp.float32): - return jnp.ones(shape, dtype) * scale - - -class FlaxBeitDropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - rate: float - - @nn.module.compact - def __call__(self, inputs, deterministic: Optional[bool] = True): - if self.rate == 0.0: - return inputs - keep_prob = 1.0 - self.rate - if deterministic: - return inputs - else: - shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - rng = self.make_rng("droppath") - random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype) - binary_tensor = jnp.floor(random_tensor) - output = inputs / keep_prob * binary_tensor - return output - - -class FlaxBeitPatchEmbeddings(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.num_channels = self.config.num_channels - image_size = self.config.image_size - patch_size = self.config.patch_size - num_patches = (image_size // patch_size) * (image_size // patch_size) - patch_shape = (image_size // patch_size, image_size // patch_size) - self.num_patches = num_patches - self.patch_shape = patch_shape - self.projection = nn.Conv( - self.config.hidden_size, - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - padding="VALID", - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - def __call__(self, pixel_values): - num_channels = pixel_values.shape[-1] - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - embeddings = self.projection(pixel_values) - batch_size, _, _, channels = embeddings.shape - return jnp.reshape(embeddings, (batch_size, -1, channels)) - - -class FlaxBeitEmbeddings(nn.Module): - """Construct the CLS token, position and patch embeddings.""" - - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) - if self.config.use_mask_token: - self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) - self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype) - num_patches = self.patch_embeddings.num_patches - if self.config.use_absolute_position_embeddings: - self.position_embeddings = self.param( - "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True): - embeddings = self.patch_embeddings(pixel_values) - batch_size, seq_len, _ = embeddings.shape - - cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) - cls_tokens = cls_tokens.astype(embeddings.dtype) - - if bool_masked_pos is not None: - mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size)) - mask_tokens = mask_tokens.astype(embeddings.dtype) - # replace the masked visual tokens by mask_tokens - w = jnp.expand_dims(bool_masked_pos, axis=-1) - embeddings = embeddings * (1 - w) + mask_tokens * w - - embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) - - if self.config.use_absolute_position_embeddings: - embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype) - - embeddings = self.dropout(embeddings, deterministic=deterministic) - return embeddings - - -class FlaxBeitRelativePositionBias(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3 - self.relative_position_bias_table = self.param( - "relative_position_bias_table", - nn.initializers.zeros, - (num_relative_distance, self.config.num_attention_heads), - ) # 2*Wh-1 * 2*Ww-1, nH - # cls to token & token 2 cls & cls to cls - - self.relative_position_index = relative_position_index_init(self.window_size) - - def __call__(self): - index = self.relative_position_index.reshape(-1) - shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) - relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH - return jnp.transpose(relative_position_bias, (2, 0, 1)) - - -class FlaxBeitSelfAttention(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr( - self.config, "embedding_size" - ): - raise ValueError( - f"The hidden size {self.config.hidden_size} is not a multiple of the number of attention " - f"heads {self.config.num_attention_heads}." - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - use_bias=False, - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.relative_position_bias = ( - FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype) - if self.window_size - else None - ) - - def __call__( - self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False - ): - head_dim = self.config.hidden_size // self.config.num_attention_heads - - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attention_bias = jnp.array(0.0, dtype=self.dtype) - # Add relative position bias if present. - if self.relative_position_bias is not None: - attention_bias = jnp.expand_dims(self.relative_position_bias(), 0) - attention_bias = attention_bias.astype(query_states.dtype) - - # Add shared relative position bias if provided. - if relative_position_bias is not None: - attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype) - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxBeitSelfOutput(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxBeitAttention(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype) - self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False - ): - attn_outputs = self.attention( - hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions - ) - attn_output = attn_outputs[0] - attn_output = self.output(attn_output, deterministic=deterministic) - - outputs = (attn_output,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -class FlaxBeitIntermediate(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - - return hidden_states - - -class FlaxBeitOutput(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - return hidden_states - - -class FlaxBeitLayer(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - drop_path_rate: float - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype) - self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype) - self.output = FlaxBeitOutput(self.config, dtype=self.dtype) - self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate) - self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - self.init_values = self.config.layer_scale_init_value - if self.init_values > 0: - self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values) - self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values) - else: - self.lambda_1 = None - self.lambda_2 = None - - def __call__( - self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False - ): - self_attention_outputs = self.attention( - self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention - relative_position_bias, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = self_attention_outputs[0] - - # apply lambda_1 if present - if self.lambda_1 is not None: - attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output - - # first residual connection - hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states - - # in BEiT, layernorm is also applied after self-attention - layer_output = self.layernorm_after(hidden_states) - - layer_output = self.intermediate(layer_output) - layer_output = self.output(layer_output, deterministic=deterministic) - - # apply lambda_2 if present - if self.lambda_2 is not None: - layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output - - # second residual connection - layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states - - outputs = (layer_output,) - - if output_attentions: - outputs += (self_attention_outputs[1],) - - return outputs - - -class FlaxBeitLayerCollection(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - drop_path_rates: list[float] - relative_position_bias: Callable[[], jnp.ndarray] - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBeitLayer( - self.config, - window_size=self.window_size if self.config.use_relative_position_bias else None, - drop_path_rate=self.drop_path_rates[i], - name=str(i), - dtype=self.dtype, - ) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None - layer_outputs = layer( - hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states,) - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxBeitEncoder(nn.Module): - config: BeitConfig - window_size: tuple[int, int] - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.use_shared_relative_position_bias: - self.relative_position_bias = FlaxBeitRelativePositionBias( - config=self.config, window_size=self.window_size, dtype=self.dtype - ) - - # stochastic depth decay rule - drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)) - self.layer = FlaxBeitLayerCollection( - self.config, - window_size=self.window_size, - drop_path_rates=drop_path_rates, - relative_position_bias=self.relative_position_bias - if self.config.use_shared_relative_position_bias - else None, - dtype=self.dtype, - ) - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BeitConfig - base_model_prefix = "beit" - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: BeitConfig, - input_shape=None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, config.num_channels) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - pixel_values = jnp.zeros(input_shape, dtype=self.dtype) - - params_rng, dropout_rng = jax.random.split(rng) - dropout_rng, droppath_rng = jax.random.split(dropout_rng) - rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} - - random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - pixel_values, - bool_masked_pos=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - dropout_rng, droppath_rng = jax.random.split(dropout_rng) - rngs["dropout"] = dropout_rng - rngs["droppath"] = droppath_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - bool_masked_pos, - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxBeitPooler(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.use_mean_pooling: - self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states): - if self.config.use_mean_pooling: - # Mean pool the final hidden states of the patch tokens - patch_tokens = hidden_states[:, 1:, :] - pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1)) - else: - # Pool by simply taking the final hidden state of the [CLS] token - pooled_output = hidden_states[:, 0] - - return pooled_output - - -class FlaxBeitModule(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - - def setup(self): - self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBeitEncoder( - self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype - ) - if not self.config.use_mean_pooling: - self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None - - def __call__( - self, - pixel_values, - bool_masked_pos=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic) - - outputs = self.encoder( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - if not self.config.use_mean_pooling: - hidden_states = self.layernorm(hidden_states) - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBeitModelOutputWithPooling( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", - BEIT_START_DOCSTRING, -) -class FlaxBeitModel(FlaxBeitPreTrainedModel): - module_class = FlaxBeitModule - - -FLAX_BEIT_MODEL_DOCSTRING = """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, FlaxBeitModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") - >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig) - - -class FlaxBeitForMaskedImageModelingModule(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype) - - # Classifier head - self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__( - self, - pixel_values=None, - bool_masked_pos=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.beit( - pixel_values, - bool_masked_pos, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - sequence_output = self.layernorm(sequence_output) - prediction_scores = self.lm_head(sequence_output[:, 1:]) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return output - - return FlaxMaskedLMOutput( - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", - BEIT_START_DOCSTRING, -) -class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel): - module_class = FlaxBeitForMaskedImageModelingModule - - -FLAX_BEIT_MLM_DOCSTRING = """ - bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k") - >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - ``` -""" - -overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING) -append_replace_return_docstrings( - FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig -) - - -class FlaxBeitForImageClassificationModule(nn.Module): - config: BeitConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True) - self.classifier = nn.Dense( - self.config.num_labels, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__( - self, - pixel_values=None, - bool_masked_pos=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.beit( - pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - logits = self.classifier(pooled_output) - - if not return_dict: - output = (logits,) + outputs[2:] - return output - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final - hidden states of the patch tokens) e.g. for ImageNet. - """, - BEIT_START_DOCSTRING, -) -class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel): - module_class = FlaxBeitForImageClassificationModule - - -FLAX_BEIT_CLASSIF_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") - >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) - ``` -""" - -overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING) -append_replace_return_docstrings( - FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig -) - - -__all__ = [ - "FlaxBeitForImageClassification", - "FlaxBeitForMaskedImageModeling", - "FlaxBeitModel", - "FlaxBeitPreTrainedModel", -] diff --git a/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py deleted file mode 100644 index 9dfd8da474e3..000000000000 --- a/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now -deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert - -TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert -weight names to the original names, so the model can be imported with Huggingface/transformer. - -You may adapt this script to include classification/MLM/NSP/etc. heads. - -Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0). - Models trained with never versions are not compatible with this script. -""" - -import argparse -import os -import re - -import tensorflow as tf -import torch - -from transformers import BertConfig, BertModel -from transformers.utils import logging - - -logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -def load_tf2_weights_in_bert(model, tf_checkpoint_path, config): - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - layer_depth = [] - for full_name, shape in init_vars: - # logger.info(f"Loading TF weight {name} with shape {shape}") - name = full_name.split("/") - if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]: - logger.info(f"Skipping non-model layer {full_name}") - continue - if "optimizer" in full_name: - logger.info(f"Skipping optimization layer {full_name}") - continue - if name[0] == "model": - # ignore initial 'model' - name = name[1:] - # figure out how many levels deep the name is - depth = 0 - for _name in name: - if _name.startswith("layer_with_weights"): - depth += 1 - else: - break - layer_depth.append(depth) - # read data - array = tf.train.load_variable(tf_path, full_name) - names.append("/".join(name)) - arrays.append(array) - logger.info(f"Read a total of {len(arrays):,} layers") - - # Sanity check - if len(set(layer_depth)) != 1: - raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})") - layer_depth = list(set(layer_depth))[0] - if layer_depth != 1: - raise ValueError( - "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP" - " heads." - ) - - # convert layers - logger.info("Converting weights...") - for full_name, array in zip(names, arrays): - name = full_name.split("/") - pointer = model - trace = [] - for i, m_name in enumerate(name): - if m_name == ".ATTRIBUTES": - # variable names end with .ATTRIBUTES/VARIABLE_VALUE - break - if m_name.startswith("layer_with_weights"): - layer_num = int(m_name.split("-")[-1]) - if layer_num <= 2: - # embedding layers - # layer_num 0: word_embeddings - # layer_num 1: position_embeddings - # layer_num 2: token_type_embeddings - continue - elif layer_num == 3: - # embedding LayerNorm - trace.extend(["embeddings", "LayerNorm"]) - pointer = getattr(pointer, "embeddings") - pointer = getattr(pointer, "LayerNorm") - elif layer_num > 3 and layer_num < config.num_hidden_layers + 4: - # encoder layers - trace.extend(["encoder", "layer", str(layer_num - 4)]) - pointer = getattr(pointer, "encoder") - pointer = getattr(pointer, "layer") - pointer = pointer[layer_num - 4] - elif layer_num == config.num_hidden_layers + 4: - # pooler layer - trace.extend(["pooler", "dense"]) - pointer = getattr(pointer, "pooler") - pointer = getattr(pointer, "dense") - elif m_name == "embeddings": - trace.append("embeddings") - pointer = getattr(pointer, "embeddings") - if layer_num == 0: - trace.append("word_embeddings") - pointer = getattr(pointer, "word_embeddings") - elif layer_num == 1: - trace.append("position_embeddings") - pointer = getattr(pointer, "position_embeddings") - elif layer_num == 2: - trace.append("token_type_embeddings") - pointer = getattr(pointer, "token_type_embeddings") - else: - raise ValueError(f"Unknown embedding layer with name {full_name}") - trace.append("weight") - pointer = getattr(pointer, "weight") - elif m_name == "_attention_layer": - # self-attention layer - trace.extend(["attention", "self"]) - pointer = getattr(pointer, "attention") - pointer = getattr(pointer, "self") - elif m_name == "_attention_layer_norm": - # output attention norm - trace.extend(["attention", "output", "LayerNorm"]) - pointer = getattr(pointer, "attention") - pointer = getattr(pointer, "output") - pointer = getattr(pointer, "LayerNorm") - elif m_name == "_attention_output_dense": - # output attention dense - trace.extend(["attention", "output", "dense"]) - pointer = getattr(pointer, "attention") - pointer = getattr(pointer, "output") - pointer = getattr(pointer, "dense") - elif m_name == "_output_dense": - # output dense - trace.extend(["output", "dense"]) - pointer = getattr(pointer, "output") - pointer = getattr(pointer, "dense") - elif m_name == "_output_layer_norm": - # output dense - trace.extend(["output", "LayerNorm"]) - pointer = getattr(pointer, "output") - pointer = getattr(pointer, "LayerNorm") - elif m_name == "_key_dense": - # attention key - trace.append("key") - pointer = getattr(pointer, "key") - elif m_name == "_query_dense": - # attention query - trace.append("query") - pointer = getattr(pointer, "query") - elif m_name == "_value_dense": - # attention value - trace.append("value") - pointer = getattr(pointer, "value") - elif m_name == "_intermediate_dense": - # attention intermediate dense - trace.extend(["intermediate", "dense"]) - pointer = getattr(pointer, "intermediate") - pointer = getattr(pointer, "dense") - elif m_name == "_output_layer_norm": - # output layer norm - trace.append("output") - pointer = getattr(pointer, "output") - # weights & biases - elif m_name in ["bias", "beta"]: - trace.append("bias") - pointer = getattr(pointer, "bias") - elif m_name in ["kernel", "gamma"]: - trace.append("weight") - pointer = getattr(pointer, "weight") - else: - logger.warning(f"Ignored {m_name}") - # for certain layers reshape is necessary - trace = ".".join(trace) - if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match( - r"(\S+)\.attention\.output\.dense\.weight", trace - ): - array = array.reshape(pointer.data.shape) - if "kernel" in full_name: - array = array.transpose() - if pointer.shape == array.shape: - pointer.data = torch.from_numpy(array) - else: - raise ValueError( - f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:" - f" {array.shape}" - ) - logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}") - return model - - -def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path): - # Instantiate model - logger.info(f"Loading model based on config from {config_path}...") - config = BertConfig.from_json_file(config_path) - model = BertModel(config) - - # Load weights from checkpoint - logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...") - load_tf2_weights_in_bert(model, tf_checkpoint_path, config) - - # Save pytorch-model - logger.info(f"Saving PyTorch model to {pytorch_dump_path}...") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path." - ) - parser.add_argument( - "--bert_config_file", - type=str, - required=True, - help="The config json file corresponding to the BERT model. This specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", - type=str, - required=True, - help="Path to the output PyTorch model (must include filename).", - ) - args = parser.parse_args() - convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index be904ddd7e6c..000000000000 --- a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,62 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert BERT checkpoint.""" - -import argparse - -import torch - -from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): - # Initialise PyTorch model - config = BertConfig.from_json_file(bert_config_file) - print(f"Building PyTorch model from configuration: {config}") - model = BertForPreTraining(config) - - # Load weights from tf checkpoint - load_tf_weights_in_bert(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--bert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py b/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py deleted file mode 100644 index 8e1e85d5c04e..000000000000 --- a/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py +++ /dev/null @@ -1,112 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" - -import argparse -import os - -import numpy as np -import tensorflow as tf -import torch - -from transformers import BertModel - - -def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): - """ - Args: - model: BertModel Pytorch model instance to be converted - ckpt_dir: Tensorflow model directory - model_name: model name - - Currently supported HF models: - - - Y BertModel - - N BertForMaskedLM - - N BertForPreTraining - - N BertForMultipleChoice - - N BertForNextSentencePrediction - - N BertForSequenceClassification - - N BertForQuestionAnswering - """ - - tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value") - - var_map = ( - ("layer.", "layer_"), - ("word_embeddings.weight", "word_embeddings"), - ("position_embeddings.weight", "position_embeddings"), - ("token_type_embeddings.weight", "token_type_embeddings"), - (".", "/"), - ("LayerNorm/weight", "LayerNorm/gamma"), - ("LayerNorm/bias", "LayerNorm/beta"), - ("weight", "kernel"), - ) - - if not os.path.isdir(ckpt_dir): - os.makedirs(ckpt_dir) - - state_dict = model.state_dict() - - def to_tf_var_name(name: str): - for patt, repl in iter(var_map): - name = name.replace(patt, repl) - return f"bert/{name}" - - def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session): - tf_dtype = tf.dtypes.as_dtype(tensor.dtype) - tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) - session.run(tf.variables_initializer([tf_var])) - session.run(tf_var) - return tf_var - - tf.reset_default_graph() - with tf.Session() as session: - for var_name in state_dict: - tf_name = to_tf_var_name(var_name) - torch_tensor = state_dict[var_name].numpy() - if any(x in var_name for x in tensors_to_transpose): - torch_tensor = torch_tensor.T - tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) - tf_var.assign(tf.cast(torch_tensor, tf_var.dtype)) - tf_weight = session.run(tf_var) - print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}") - - saver = tf.train.Saver(tf.trainable_variables()) - saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) - - -def main(raw_args=None): - parser = argparse.ArgumentParser() - parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased") - parser.add_argument( - "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model" - ) - parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/.bin") - parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model") - args = parser.parse_args(raw_args) - - model = BertModel.from_pretrained( - pretrained_model_name_or_path=args.model_name, - state_dict=torch.load(args.pytorch_model_path, weights_only=True), - cache_dir=args.cache_dir, - ) - - convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name) - - -if __name__ == "__main__": - main() diff --git a/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py deleted file mode 100644 index a7832a53d55d..000000000000 --- a/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT -model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository: - -https://github.com/tensorflow/models/tree/master/official/projects/token_dropping -""" - -import argparse - -import tensorflow as tf -import torch - -from transformers import BertConfig, BertForMaskedLM -from transformers.models.bert.modeling_bert import ( - BertIntermediate, - BertLayer, - BertOutput, - BertPooler, - BertSelfAttention, - BertSelfOutput, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str): - def get_masked_lm_array(name: str): - full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE" - array = tf.train.load_variable(tf_checkpoint_path, full_name) - - if "kernel" in name: - array = array.transpose() - - return torch.from_numpy(array) - - def get_encoder_array(name: str): - full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE" - array = tf.train.load_variable(tf_checkpoint_path, full_name) - - if "kernel" in name: - array = array.transpose() - - return torch.from_numpy(array) - - def get_encoder_layer_array(layer_index: int, name: str): - full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE" - array = tf.train.load_variable(tf_checkpoint_path, full_name) - - if "kernel" in name: - array = array.transpose() - - return torch.from_numpy(array) - - def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape): - full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE" - array = tf.train.load_variable(tf_checkpoint_path, full_name) - array = array.reshape(original_shape) - - if "kernel" in name: - array = array.transpose() - - return torch.from_numpy(array) - - print(f"Loading model based on config from {config_path}...") - config = BertConfig.from_json_file(config_path) - model = BertForMaskedLM(config) - - # Layers - for layer_index in range(0, config.num_hidden_layers): - layer: BertLayer = model.bert.encoder.layer[layer_index] - - # Self-attention - self_attn: BertSelfAttention = layer.attention.self - - self_attn.query.weight.data = get_encoder_attention_layer_array( - layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape - ) - self_attn.query.bias.data = get_encoder_attention_layer_array( - layer_index, "_query_dense/bias", self_attn.query.bias.data.shape - ) - self_attn.key.weight.data = get_encoder_attention_layer_array( - layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape - ) - self_attn.key.bias.data = get_encoder_attention_layer_array( - layer_index, "_key_dense/bias", self_attn.key.bias.data.shape - ) - self_attn.value.weight.data = get_encoder_attention_layer_array( - layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape - ) - self_attn.value.bias.data = get_encoder_attention_layer_array( - layer_index, "_value_dense/bias", self_attn.value.bias.data.shape - ) - - # Self-attention Output - self_output: BertSelfOutput = layer.attention.output - - self_output.dense.weight.data = get_encoder_attention_layer_array( - layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape - ) - self_output.dense.bias.data = get_encoder_attention_layer_array( - layer_index, "_output_dense/bias", self_output.dense.bias.data.shape - ) - - self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma") - self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta") - - # Intermediate - intermediate: BertIntermediate = layer.intermediate - - intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel") - intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias") - - # Output - bert_output: BertOutput = layer.output - - bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel") - bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias") - - bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma") - bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta") - - # Embeddings - model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings") - model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings") - model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma") - model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta") - - # LM Head - lm_head = model.cls.predictions.transform - - lm_head.dense.weight.data = get_masked_lm_array("dense/kernel") - lm_head.dense.bias.data = get_masked_lm_array("dense/bias") - - lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma") - lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta") - - model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table") - - # Pooling - model.bert.pooler = BertPooler(config=config) - model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel") - model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias") - - # Export final model - model.save_pretrained(pytorch_dump_path) - - # Integration test - should load without any errors ;) - new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path) - print(new_model.eval()) - - print("Model conversion was done successfully!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path." - ) - parser.add_argument( - "--bert_config_file", - type=str, - required=True, - help="The config json file corresponding to the BERT model. This specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", - type=str, - required=True, - help="Path to the output PyTorch model.", - ) - args = parser.parse_args() - convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py deleted file mode 100644 index 37828642eb4e..000000000000 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ /dev/null @@ -1,1727 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxNextSentencePredictorOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_bert import BertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased" -_CONFIG_FOR_DOC = "BertConfig" - -remat = nn_partitioning.remat - - -@flax.struct.dataclass -class FlaxBertForPreTrainingOutput(ModelOutput): - """ - Output type of [`BertForPreTraining`]. - - Args: - prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - prediction_logits: jnp.ndarray = None - seq_relationship_logits: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -BERT_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. - -""" - -BERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - -""" - - -class FlaxBertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxBertSelfAttention(nn.Module): - config: BertConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxBertSelfOutput(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class FlaxBertAttention(nn.Module): - config: BertConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -class FlaxBertIntermediate(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class FlaxBertOutput(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -class FlaxBertLayer(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) - self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) - self.output = FlaxBertOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -class FlaxBertLayerCollection(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class FlaxBertEncoder(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxBertLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxBertPooler(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.tanh(cls_hidden_state) - - -class FlaxBertPredictionHeadTransform(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.activation = ACT2FN[self.config.hidden_act] - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return self.LayerNorm(hidden_states) - - -class FlaxBertLMPredictionHead(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.transform(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -class FlaxBertOnlyMLMHead(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) - return hidden_states - - -class FlaxBertOnlyNSPHead(nn.Module): - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.seq_relationship = nn.Dense(2, dtype=self.dtype) - - def __call__(self, pooled_output): - return self.seq_relationship(pooled_output) - - -class FlaxBertPreTrainingHeads(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) - self.seq_relationship = nn.Dense(2, dtype=self.dtype) - - def __call__(self, hidden_states, pooled_output, shared_embedding=None): - prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class FlaxBertPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BertConfig - base_model_prefix = "bert" - module_class: nn.Module = None - - def __init__( - self, - config: BertConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class( - config=config, - dtype=dtype, - gradient_checkpointing=gradient_checkpointing, - **kwargs, - ) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.zeros_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxBertAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -class FlaxBertModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBertEncoder( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", - BERT_START_DOCSTRING, -) -class FlaxBertModel(FlaxBertPreTrainedModel): - module_class = FlaxBertModule - - -append_call_sample_docstring(FlaxBertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) - - -class FlaxBertForPreTrainingModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - hidden_states = outputs[0] - pooled_output = outputs[1] - - prediction_scores, seq_relationship_score = self.cls( - hidden_states, pooled_output, shared_embedding=shared_embedding - ) - - if not return_dict: - return (prediction_scores, seq_relationship_score) + outputs[2:] - - return FlaxBertForPreTrainingOutput( - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next - sentence prediction (classification)` head. - """, - BERT_START_DOCSTRING, -) -class FlaxBertForPreTraining(FlaxBertPreTrainedModel): - module_class = FlaxBertForPreTrainingModule - - -FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBertForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.prediction_logits - >>> seq_relationship_logits = outputs.seq_relationship_logits - ``` -""" - -overwrite_call_docstring( - FlaxBertForPreTraining, - BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING, -) -append_replace_return_docstrings( - FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxBertForMaskedLMModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.cls(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) -class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): - module_class = FlaxBertForMaskedLMModule - - -append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) - - -class FlaxBertForNextSentencePredictionModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - seq_relationship_scores = self.cls(pooled_output) - - if not return_dict: - return (seq_relationship_scores,) + outputs[2:] - - return FlaxNextSentencePredictorOutput( - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """Bert Model with a `next sentence prediction (classification)` head on top.""", - BERT_START_DOCSTRING, -) -class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): - module_class = FlaxBertForNextSentencePredictionModule - - -FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax") - - >>> outputs = model(**encoding) - >>> logits = outputs.logits - >>> assert logits[0, 0] < logits[0, 1] # next sentence was random - ``` -""" - - -overwrite_call_docstring( - FlaxBertForNextSentencePrediction, - BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING, -) -append_replace_return_docstrings( - FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxBertForSequenceClassificationModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - if not return_dict: - return (logits,) + outputs[2:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - BERT_START_DOCSTRING, -) -class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): - module_class = FlaxBertForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxBertForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBertForMultipleChoiceModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - BERT_START_DOCSTRING, -) -class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): - module_class = FlaxBertForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC -) - - -class FlaxBertForTokenClassificationModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - BERT_START_DOCSTRING, -) -class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): - module_class = FlaxBertForTokenClassificationModule - - -append_call_sample_docstring( - FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC -) - - -class FlaxBertForQuestionAnsweringModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - BERT_START_DOCSTRING, -) -class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): - module_class = FlaxBertForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxBertForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBertForCausalLMModule(nn.Module): - config: BertConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBertModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.cls(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for - autoregressive tasks. - """, - BERT_START_DOCSTRING, -) -class FlaxBertForCausalLM(FlaxBertPreTrainedModel): - module_class = FlaxBertForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxBertForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxBertForCausalLM", - "FlaxBertForMaskedLM", - "FlaxBertForMultipleChoice", - "FlaxBertForNextSentencePrediction", - "FlaxBertForPreTraining", - "FlaxBertForQuestionAnswering", - "FlaxBertForSequenceClassification", - "FlaxBertForTokenClassification", - "FlaxBertModel", - "FlaxBertPreTrainedModel", -] diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py deleted file mode 100644 index 1ca82f9f1820..000000000000 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ /dev/null @@ -1,2125 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 BERT model.""" - -from __future__ import annotations - -import math -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFNextSentencePredictorOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFNextSentencePredictionLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_bert import BertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased" -_CONFIG_FOR_DOC = "BertConfig" - -# TokenClassification docstring -_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" -_TOKEN_CLASS_EXPECTED_OUTPUT = ( - "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " -) -_TOKEN_CLASS_EXPECTED_LOSS = 0.01 - -# QuestionAnswering docstring -_CHECKPOINT_FOR_QA = "ydshieh/bert-base-cased-squad2" -_QA_EXPECTED_OUTPUT = "'a nice puppet'" -_QA_EXPECTED_LOSS = 7.41 -_QA_TARGET_START_INDEX = 14 -_QA_TARGET_END_INDEX = 15 - -# SequenceClassification docstring -_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/bert-base-uncased-yelp-polarity" -_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" -_SEQ_CLASS_EXPECTED_LOSS = 0.01 - - -class TFBertPreTrainingLoss: - """ - Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining - NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss - computation. - """ - - def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0]) - # make sure only labels that are not equal to -100 - # are taken into account for the loss computation - lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype) - masked_lm_losses = unmasked_lm_losses * lm_loss_mask - reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1]) - ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype) - masked_ns_loss = unmasked_ns_loss * ns_loss_mask - - reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask) - - return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,)) - - -class TFBertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - past_key_values_length=0, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFBertSelfAttention(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -class TFBertSelfOutput(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFBertAttention(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFBertSelfAttention(config, name="self") - self.dense_output = TFBertSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFBertIntermediate(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFBertOutput(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFBertLayer(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFBertAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFBertAttention(config, name="crossattention") - self.intermediate = TFBertIntermediate(config, name="intermediate") - self.bert_output = TFBertOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -class TFBertEncoder(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFBertPooler(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFBertPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFBertLMPredictionHead(keras.layers.Layer): - def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - - self.transform = TFBertPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -class TFBertMLMHead(keras.layers.Layer): - def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -class TFBertNSPHead(keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): - super().__init__(**kwargs) - - self.seq_relationship = keras.layers.Dense( - units=2, - kernel_initializer=get_initializer(config.initializer_range), - name="seq_relationship", - ) - self.config = config - - def call(self, pooled_output: tf.Tensor) -> tf.Tensor: - seq_relationship_score = self.seq_relationship(inputs=pooled_output) - - return seq_relationship_score - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "seq_relationship", None) is not None: - with tf.name_scope(self.seq_relationship.name): - self.seq_relationship.build([None, None, self.config.hidden_size]) - - -@keras_serializable -class TFBertMainLayer(keras.layers.Layer): - config_class = BertConfig - - def __init__(self, config: BertConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.embeddings = TFBertEmbeddings(config, name="embeddings") - self.encoder = TFBertEncoder(config, name="encoder") - self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFBertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BertConfig - base_model_prefix = "bert" - - -@dataclass -class TFBertForPreTrainingOutput(ModelOutput): - """ - Output type of [`TFBertForPreTraining`]. - - Args: - prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - prediction_logits: tf.Tensor | None = None - seq_relationship_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | tf.Tensor | None = None - attentions: tuple[tf.Tensor] | tf.Tensor | None = None - - -BERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`BertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", - BERT_START_DOCSTRING, -) -class TFBertModel(TFBertPreTrainedModel): - def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - - -@add_start_docstrings( - """ -Bert Model with two heads on top as done during the pretraining: - a `masked language modeling` head and a `next sentence prediction (classification)` head. - """, - BERT_START_DOCSTRING, -) -class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"position_ids", - r"cls.predictions.decoder.weight", - r"cls.predictions.decoder.bias", - ] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.bert = TFBertMainLayer(config, name="bert") - self.nsp = TFBertNSPHead(config, name="nsp___cls") - self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - def get_prefix_bias_name(self) -> str: - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - next_sentence_label: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFBertForPreTrainingOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - next_sentence_label (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see `input_ids` docstring) Indices should be in `[0, 1]`: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - kwargs (`dict[str, any]`, *optional*, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFBertForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> model = TFBertForPreTraining.from_pretrained("google-bert/bert-base-uncased") - >>> input_ids = tokenizer("Hello, my dog is cute", add_special_tokens=True, return_tensors="tf") - >>> # Batch size 1 - - >>> outputs = model(input_ids) - >>> prediction_logits, seq_relationship_logits = outputs[:2] - ```""" - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output, pooled_output = outputs[:2] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - seq_relationship_score = self.nsp(pooled_output=pooled_output) - total_loss = None - - if labels is not None and next_sentence_label is not None: - d_labels = {"labels": labels} - d_labels["next_sentence_label"] = next_sentence_label - total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score)) - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return TFBertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "nsp", None) is not None: - with tf.name_scope(self.nsp.name): - self.nsp.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) -class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"cls.seq_relationship", - r"cls.predictions.decoder.weight", - r"nsp___cls", - ] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") - self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - def get_prefix_bias_name(self) -> str: - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'paris'", - expected_loss=0.88, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"cls.seq_relationship", - r"cls.predictions.decoder.weight", - r"nsp___cls", - ] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`") - - self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") - self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - def get_prefix_bias_name(self) -> str: - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - **kwargs, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """Bert Model with a `next sentence prediction (classification)` head on top.""", - BERT_START_DOCSTRING, -) -class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.bert = TFBertMainLayer(config, name="bert") - self.nsp = TFBertNSPHead(config, name="nsp___cls") - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - next_sentence_label: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFNextSentencePredictorOutput | tuple[tf.Tensor]: - r""" - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFBertForNextSentencePrediction - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> model = TFBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf") - - >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0] - >>> assert logits[0][0] < logits[0][1] # the next sentence was random - ```""" - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - seq_relationship_scores = self.nsp(pooled_output=pooled_output) - next_sentence_loss = ( - None - if next_sentence_label is None - else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) - ) - - if not return_dict: - output = (seq_relationship_scores,) + outputs[2:] - return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output - - return TFNextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "nsp", None) is not None: - with tf.name_scope(self.nsp.name): - self.nsp.build(None) - - -@add_start_docstrings( - """ - Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - BERT_START_DOCSTRING, -) -class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.bert = TFBertMainLayer(config, name="bert") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(rate=classifier_dropout) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, - expected_loss=_SEQ_CLASS_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - BERT_START_DOCSTRING, -) -class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.bert = TFBertMainLayer(config, name="bert") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None - flat_attention_mask = ( - tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None - ) - flat_token_type_ids = ( - tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None - ) - flat_position_ids = ( - tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None - ) - flat_inputs_embeds = ( - tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.bert( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - BERT_START_DOCSTRING, -) -class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"mlm___cls", - r"nsp___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(rate=classifier_dropout) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, - expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - BERT_START_DOCSTRING, -) -class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"mlm___cls", - r"nsp___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - - def __init__(self, config: BertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="qa_outputs", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_QA, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - qa_target_start_index=_QA_TARGET_START_INDEX, - qa_target_end_index=_QA_TARGET_END_INDEX, - expected_output=_QA_EXPECTED_OUTPUT, - expected_loss=_QA_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFBertEmbeddings", - "TFBertForMaskedLM", - "TFBertForMultipleChoice", - "TFBertForNextSentencePrediction", - "TFBertForPreTraining", - "TFBertForQuestionAnswering", - "TFBertForSequenceClassification", - "TFBertForTokenClassification", - "TFBertLMHeadModel", - "TFBertMainLayer", - "TFBertModel", - "TFBertPreTrainedModel", -] diff --git a/src/transformers/models/bert/tokenization_bert_tf.py b/src/transformers/models/bert/tokenization_bert_tf.py deleted file mode 100644 index c8fca52c4cbf..000000000000 --- a/src/transformers/models/bert/tokenization_bert_tf.py +++ /dev/null @@ -1,259 +0,0 @@ -import os -from typing import Optional, Union - -import tensorflow as tf -from tensorflow_text import BertTokenizer as BertTokenizerLayer -from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs - -from ...modeling_tf_utils import keras -from ...utils.import_utils import requires -from .tokenization_bert import BertTokenizer - - -@requires(backends=("tf", "tensorflow_text")) -class TFBertTokenizer(keras.layers.Layer): - """ - This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the - `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings - from an existing standard tokenizer object. - - In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run - when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options - than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes - straight from `tf.string` inputs to outputs. - - Args: - vocab_list (`list`): - List containing the vocabulary. - do_lower_case (`bool`, *optional*, defaults to `True`): - Whether or not to lowercase the input when tokenizing. - cls_token_id (`str`, *optional*, defaults to `"[CLS]"`): - The classifier token which is used when doing sequence classification (classification of the whole sequence - instead of per-token classification). It is the first token of the sequence when built with special tokens. - sep_token_id (`str`, *optional*, defaults to `"[SEP]"`): - The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for - sequence classification or for a text and a question for question answering. It is also used as the last - token of a sequence built with special tokens. - pad_token_id (`str`, *optional*, defaults to `"[PAD]"`): - The token used for padding, for example when batching sequences of different lengths. - padding (`str`, defaults to `"longest"`): - The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch, - or `"max_length", to pad all inputs to the maximum length supported by the tokenizer. - truncation (`bool`, *optional*, defaults to `True`): - Whether to truncate the sequence to the maximum length. - max_length (`int`, *optional*, defaults to `512`): - The maximum length of the sequence, used for padding (if `padding` is "max_length") and/or truncation (if - `truncation` is `True`). - pad_to_multiple_of (`int`, *optional*, defaults to `None`): - If set, the sequence will be padded to a multiple of this value. - return_token_type_ids (`bool`, *optional*, defaults to `True`): - Whether to return token_type_ids. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether to return the attention_mask. - use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`): - If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer - class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to - TFLite. - """ - - def __init__( - self, - vocab_list: list, - do_lower_case: bool, - cls_token_id: Optional[int] = None, - sep_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - padding: str = "longest", - truncation: bool = True, - max_length: int = 512, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: bool = True, - return_attention_mask: bool = True, - use_fast_bert_tokenizer: bool = True, - **tokenizer_kwargs, - ): - super().__init__() - if use_fast_bert_tokenizer: - self.tf_tokenizer = FastBertTokenizer( - vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs - ) - else: - lookup_table = tf.lookup.StaticVocabularyTable( - tf.lookup.KeyValueTensorInitializer( - keys=vocab_list, - key_dtype=tf.string, - values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64), - value_dtype=tf.int64, - ), - num_oov_buckets=1, - ) - self.tf_tokenizer = BertTokenizerLayer( - lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs - ) - - self.vocab_list = vocab_list - self.do_lower_case = do_lower_case - self.cls_token_id = vocab_list.index("[CLS]") if cls_token_id is None else cls_token_id - self.sep_token_id = vocab_list.index("[SEP]") if sep_token_id is None else sep_token_id - self.pad_token_id = vocab_list.index("[PAD]") if pad_token_id is None else pad_token_id - self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1) # Allow room for special tokens - self.max_length = max_length - self.padding = padding - self.truncation = truncation - self.pad_to_multiple_of = pad_to_multiple_of - self.return_token_type_ids = return_token_type_ids - self.return_attention_mask = return_attention_mask - - @classmethod - def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs): # noqa: F821 - """ - Initialize a `TFBertTokenizer` from an existing `Tokenizer`. - - Args: - tokenizer (`PreTrainedTokenizerBase`): - The tokenizer to use to initialize the `TFBertTokenizer`. - - Examples: - - ```python - from transformers import AutoTokenizer, TFBertTokenizer - - tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer) - ``` - """ - do_lower_case = kwargs.pop("do_lower_case", None) - do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case - cls_token_id = kwargs.pop("cls_token_id", None) - cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id - sep_token_id = kwargs.pop("sep_token_id", None) - sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id - pad_token_id = kwargs.pop("pad_token_id", None) - pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id - - vocab = tokenizer.get_vocab() - vocab = sorted(vocab.items(), key=lambda x: x[1]) - vocab_list = [entry[0] for entry in vocab] - return cls( - vocab_list=vocab_list, - do_lower_case=do_lower_case, - cls_token_id=cls_token_id, - sep_token_id=sep_token_id, - pad_token_id=pad_token_id, - **kwargs, - ) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): - """ - Instantiate a `TFBertTokenizer` from a pre-trained tokenizer. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - The name or path to the pre-trained tokenizer. - - Examples: - - ```python - from transformers import TFBertTokenizer - - tf_tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased") - ``` - """ - try: - tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) - except: # noqa: E722 - from .tokenization_bert_fast import BertTokenizerFast - - tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) - return cls.from_tokenizer(tokenizer, **kwargs) - - def unpaired_tokenize(self, texts): - if self.do_lower_case: - texts = case_fold_utf8(texts) - tokens = self.tf_tokenizer.tokenize(texts) - return tokens.merge_dims(1, -1) - - def call( - self, - text, - text_pair=None, - padding=None, - truncation=None, - max_length=None, - pad_to_multiple_of=None, - return_token_type_ids=None, - return_attention_mask=None, - ): - if padding is None: - padding = self.padding - if padding not in ("longest", "max_length"): - raise ValueError("Padding must be either 'longest' or 'max_length'!") - if max_length is not None and text_pair is not None: - # Because we have to instantiate a Trimmer to do it properly - raise ValueError("max_length cannot be overridden at call time when truncating paired texts!") - if max_length is None: - max_length = self.max_length - if truncation is None: - truncation = self.truncation - if pad_to_multiple_of is None: - pad_to_multiple_of = self.pad_to_multiple_of - if return_token_type_ids is None: - return_token_type_ids = self.return_token_type_ids - if return_attention_mask is None: - return_attention_mask = self.return_attention_mask - if not isinstance(text, tf.Tensor): - text = tf.convert_to_tensor(text) - if text_pair is not None and not isinstance(text_pair, tf.Tensor): - text_pair = tf.convert_to_tensor(text_pair) - if text_pair is not None: - if text.shape.rank > 1: - raise ValueError("text argument should not be multidimensional when a text pair is supplied!") - if text_pair.shape.rank > 1: - raise ValueError("text_pair should not be multidimensional!") - if text.shape.rank == 2: - text, text_pair = text[:, 0], text[:, 1] - text = self.unpaired_tokenize(text) - if text_pair is None: # Unpaired text - if truncation: - text = text[:, : max_length - 2] # Allow room for special tokens - input_ids, token_type_ids = combine_segments( - (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id - ) - else: # Paired text - text_pair = self.unpaired_tokenize(text_pair) - if truncation: - text, text_pair = self.paired_trimmer.trim([text, text_pair]) - input_ids, token_type_ids = combine_segments( - (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id - ) - if padding == "longest": - pad_length = input_ids.bounding_shape(axis=1) - if pad_to_multiple_of is not None: - # No ceiling division in tensorflow, so we negate floordiv instead - pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of)) - else: - pad_length = max_length - - input_ids, attention_mask = pad_model_inputs(input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id) - output = {"input_ids": input_ids} - if return_attention_mask: - output["attention_mask"] = attention_mask - if return_token_type_ids: - token_type_ids, _ = pad_model_inputs( - token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id - ) - output["token_type_ids"] = token_type_ids - return output - - def get_config(self): - return { - "vocab_list": self.vocab_list, - "do_lower_case": self.do_lower_case, - "cls_token_id": self.cls_token_id, - "sep_token_id": self.sep_token_id, - "pad_token_id": self.pad_token_id, - } - - -__all__ = ["TFBertTokenizer"] diff --git a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 0b8e6590f937..000000000000 --- a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,69 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert BigBird checkpoint.""" - -import argparse - -from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): - # Initialise PyTorch model - config = BigBirdConfig.from_json_file(big_bird_config_file) - print(f"Building PyTorch model from configuration: {config}") - - if is_trivia_qa: - model = BigBirdForQuestionAnswering(config) - else: - model = BigBirdForPreTraining(config) - - # Load weights from tf checkpoint - load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--big_bird_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa - ) diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py deleted file mode 100644 index 11dcb30f3d47..000000000000 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ /dev/null @@ -1,2648 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_big_bird import BigBirdConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base" -_CONFIG_FOR_DOC = "BigBirdConfig" - -remat = nn_partitioning.remat - - -@flax.struct.dataclass -class FlaxBigBirdForPreTrainingOutput(ModelOutput): - """ - Output type of [`BigBirdForPreTraining`]. - - Args: - prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - prediction_logits: jnp.ndarray = None - seq_relationship_logits: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering models. - - Args: - start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - pooled_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): - pooled_output returned by FlaxBigBirdModel. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - start_logits: jnp.ndarray = None - end_logits: jnp.ndarray = None - pooled_output: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -BIG_BIRD_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BigBirdConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BIG_BIRD_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - -""" - - -class FlaxBigBirdEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.setup - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - if self.config.rescale_embeddings: - inputs_embeds *= self.config.hidden_size**0.5 - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird -class FlaxBigBirdSelfAttention(nn.Module): - config: BigBirdConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxBigBirdBlockSparseAttention(nn.Module): - config: BigBirdConfig - block_sparse_seed: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - @staticmethod - def transpose_for_scores(x, n_heads, head_size): - new_x_shape = x.shape[:-1] + (n_heads, head_size) - x = x.reshape(*new_x_shape) - return jnp.transpose(x, axes=(0, 2, 1, 3)) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic=True, - output_attentions=False, - ): - n_heads = self.config.num_attention_heads - head_size = self.config.hidden_size // n_heads - - blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( - attention_mask, self.config.block_size - ) - - query_layer = self.transpose_for_scores(self.query(hidden_states), n_heads, head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states), n_heads, head_size) - value_layer = self.transpose_for_scores(self.value(hidden_states), n_heads, head_size) - - indices_prng_key = None - if not deterministic: - indices_prng_key = self.make_rng("indices") - - attn_output, attn_weights = self.bigbird_block_sparse_attention( - query_layer, - key_layer, - value_layer, - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - blocked_encoder_mask, - n_heads, - head_size, - indices_prng_key=indices_prng_key, - deterministic=deterministic, - plan_from_length=None, - plan_num_rand_blocks=None, - output_attentions=output_attentions, - ) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - @staticmethod - def create_masks_for_block_sparse_attn(attention_mask, block_size: int): - batch_size, seq_length = attention_mask.shape - if seq_length % block_size != 0: - raise ValueError( - f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block" - f" size is {block_size}." - ) - - def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): - """ - Create 3D attention mask from a 2D tensor mask. - - Args: - from_blocked_mask: 2D Tensor of shape [batch_size, - from_seq_length//from_block_size, from_block_size]. - to_blocked_mask: int32 Tensor of shape [batch_size, - to_seq_length//to_block_size, to_block_size]. - - Returns: - float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, - 3*to_block_size]. - """ - exp_blocked_to_pad = jnp.concatenate( - [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], axis=2 - ) - band_mask = jnp.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) - band_mask = jnp.expand_dims(band_mask, 1) - return band_mask - - blocked_encoder_mask = attention_mask.reshape(batch_size, seq_length // block_size, block_size) - band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) - - from_mask = attention_mask.reshape(batch_size, 1, seq_length, 1) - to_mask = attention_mask.reshape(batch_size, 1, 1, seq_length) - - return blocked_encoder_mask, band_mask, from_mask, to_mask - - def bigbird_block_sparse_attention( - self, - query_layer, - key_layer, - value_layer, - band_mask, - from_mask, - to_mask, - from_blocked_mask, - to_blocked_mask, - n_heads, - head_size, - indices_prng_key: Optional[jax.random.PRNGKey] = None, - deterministic: Optional[bool] = True, - plan_from_length=None, - plan_num_rand_blocks=None, - output_attentions=None, - ): - # BigBird block-sparse attention as suggested in paper - - # ITC: - # global tokens: 2 x block_size - # window tokens: 3 x block_size - # random tokens: num_rand_tokens x block_size - - # ETC: - # global tokens: extra_globals_tokens + 2 x block_size - # window tokens: 3 x block_size - # random tokens: num_rand_tokens x block_size - - # Note: - # 1) Currently, ETC is not supported. - # 2) Window size is fixed to 3 blocks & it can be changed only by - # changing `block_size`. - # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be - # controlled only by `block_size`. - - # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of - # shifting tokens (for calculating sliding attention). hence following code can be divided into 5 parts. - - bsz, _, from_seq_len, _ = query_layer.shape - to_seq_len = key_layer.shape[2] - from_block_size = to_block_size = self.config.block_size - - if from_seq_len % from_block_size != 0: - raise ValueError("Query sided sequence length must be multiple of block size") - - if to_seq_len % to_block_size != 0: - raise ValueError("Key/Value sided sequence length must be multiple of block size") - - if from_seq_len // from_block_size != to_seq_len // to_block_size: - raise ValueError("Error the number of blocks needs to be same!") - - n_rand_blocks = self.config.num_random_blocks - rsqrt_d = 1 / jnp.sqrt(head_size) - attn_mask_penalty = -10000.0 - - if from_seq_len in [1024, 3072, 4096]: # old plans used in paper - max_seqlen = self.config.max_position_embeddings - rand_attn = [ - self._bigbird_block_rand_mask( - max_seqlen, - max_seqlen, - from_block_size, - to_block_size, - n_rand_blocks, - indices_prng_key=indices_prng_key, - deterministic=deterministic, - last_idx=1024, - )[: (from_seq_len // from_block_size - 2)] - for _ in range(n_heads) - ] - else: - if plan_from_length is None: - plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( - from_seq_len, from_block_size, n_rand_blocks - ) - rand_attn = self._bigbird_block_rand_mask_with_head( - from_seq_length=from_seq_len, - to_seq_length=to_seq_len, - from_block_size=from_block_size, - to_block_size=to_block_size, - num_heads=n_heads, - plan_from_length=plan_from_length, - plan_num_rand_blocks=plan_num_rand_blocks, - indices_prng_key=indices_prng_key, - ) - - rand_attn = jnp.stack(rand_attn, axis=0) - rand_attn = jnp.broadcast_to(rand_attn, (bsz,) + rand_attn.shape) - - rand_mask = self._create_rand_mask_from_inputs( - from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size - ) - - blocked_query_matrix = query_layer.reshape(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) - blocked_key_matrix = key_layer.reshape(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) - blocked_value_matrix = value_layer.reshape(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) - - shape = (bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1) - gathered_key = self.jax_gather(blocked_key_matrix, rand_attn, batch_dims=2).reshape(*shape) - gathered_value = self.jax_gather(blocked_value_matrix, rand_attn, batch_dims=2).reshape(*shape) - - # 1st PART - # 1st block (global block) attention scores - # q[0] x (k[0], k[1], k[2], k[3], k[4] .... ) - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] - first_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, 0], key_layer) - - first_product = first_product * rsqrt_d - first_product += (1.0 - to_mask) * attn_mask_penalty - first_attn_weights = jax.nn.softmax(first_product, axis=-1) # [bsz, n_heads, from_block_size, to_seq_len] - - # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] - first_context_layer = jnp.einsum("bhqk,bhkd->bhqd", first_attn_weights, value_layer) - first_context_layer = jnp.expand_dims(first_context_layer, 2) - - # 2nd PART - # 2nd block attention scores - # q[1] x (sliding_keys, random_keys, global_keys) - # sliding key blocks -> 2nd, 3rd blocks - # global key blocks -> 1st block - - second_key_mat = jnp.concatenate( - [ - blocked_key_matrix[:, :, 0], - blocked_key_matrix[:, :, 1], - blocked_key_matrix[:, :, 2], - blocked_key_matrix[:, :, -1], - gathered_key[:, :, 0], - ], - axis=2, - ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - second_value_mat = jnp.concatenate( - [ - blocked_value_matrix[:, :, 0], - blocked_value_matrix[:, :, 1], - blocked_value_matrix[:, :, 2], - blocked_value_matrix[:, :, -1], - gathered_value[:, :, 0], - ], - axis=2, - ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - # ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - second_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, 1], second_key_mat) - second_seq_pad = jnp.concatenate( - [ - to_mask[:, :, :, : 3 * to_block_size], - to_mask[:, :, :, -to_block_size:], - jnp.ones([bsz, 1, 1, n_rand_blocks * to_block_size], dtype=to_mask.dtype), - ], - axis=3, - ) - second_rand_pad = jnp.concatenate( - [ - jnp.ones([bsz, n_heads, from_block_size, 4 * to_block_size], dtype=rand_mask.dtype), - rand_mask[:, :, 0], - ], - axis=3, - ) - second_product = second_product * rsqrt_d - second_product += (1.0 - jnp.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty - second_attn_weights = jax.nn.softmax( - second_product, axis=-1 - ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - - # [bsz, n_heads, from_block_size, (4+r)*to_block_size] x [bsz, n_heads, (4+r)*to_block_size, -1] - # ==> [bsz, n_heads, from_block_size, -1] - second_context_layer = jnp.einsum("bhqk,bhkd->bhqd", second_attn_weights, second_value_mat) - second_context_layer = jnp.expand_dims(second_context_layer, 2) - - # 3rd PART - # Middle blocks attention scores - # q[-2:2] x (sliding_keys, random_keys, global_keys) - # sliding attn is calculated using special trick of shifting tokens as discussed in paper - # random keys are generated by taking random indices as per `rand_attn` - # global keys -> 1st & last block - - exp_blocked_key_matrix = jnp.concatenate( - [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], axis=3 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - exp_blocked_value_matrix = jnp.concatenate( - [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], - axis=3, - ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - middle_query_matrix = blocked_query_matrix[:, :, 2:-2] - - # sliding attention scores for q[-2:2] - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - inner_band_product = jnp.einsum("bhlqd,bhlkd->bhlqk", middle_query_matrix, exp_blocked_key_matrix) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] - inner_band_product = inner_band_product * rsqrt_d - - # randn attention scores for q[-2:2] - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - # x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] - rand_band_product = jnp.einsum("bhlqd,bhlkd->bhlqk", middle_query_matrix, gathered_key[:, :, 1:-1]) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] - rand_band_product = rand_band_product * rsqrt_d - - # Including 1st block (since it's global) - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] - first_band_product = jnp.einsum("bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0]) - first_band_product = first_band_product * rsqrt_d - - # Including last block (since it's global) - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] - last_band_product = jnp.einsum("bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1]) - last_band_product = last_band_product * rsqrt_d - - # masking padded tokens - inner_band_product += (1.0 - band_mask) * attn_mask_penalty - first_band_product += (1.0 - jnp.expand_dims(to_mask[:, :, :, :to_block_size], 3)) * attn_mask_penalty - last_band_product += (1.0 - jnp.expand_dims(to_mask[:, :, :, -to_block_size:], 3)) * attn_mask_penalty - rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty - - # completing attention scores matrix for all q[-2:2] - band_product = jnp.concatenate( - [first_band_product, inner_band_product, rand_band_product, last_band_product], axis=-1 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] - - # safely doing softmax since attention matrix is completed - attn_weights = jax.nn.softmax( - band_product, axis=-1 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] - - # contribution of sliding keys - # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] - # x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - context_layer = jnp.einsum( - "bhlqk,bhlkd->bhlqd", attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix - ) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - - # adding contribution of random keys - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] - # x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] - context_layer += jnp.einsum( - "bhlqk,bhlkd->bhlqd", - attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], - gathered_value[:, :, 1:-1], - ) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - - # adding contribution of global keys - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - context_layer += jnp.einsum( - "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] - ) - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - context_layer += jnp.einsum( - "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] - ) - - # 4th PART - # last 2nd token attention scores - # q[-2] x (sliding_keys, random_keys, global_keys) - # sliding key blocks -> last 3 blocks - # global key block -> 1st block - # random key block -> based on indices stored in `randn_attn` - - second_last_key_mat = jnp.concatenate( - [ - blocked_key_matrix[:, :, 0], - blocked_key_matrix[:, :, -3], - blocked_key_matrix[:, :, -2], - blocked_key_matrix[:, :, -1], - gathered_key[:, :, -1], - ], - axis=2, - ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] - second_last_value_mat = jnp.concatenate( - [ - blocked_value_matrix[:, :, 0], - blocked_value_matrix[:, :, -3], - blocked_value_matrix[:, :, -2], - blocked_value_matrix[:, :, -1], - gathered_value[:, :, -1], - ], - axis=2, - ) # [bsz, n_heads, (4+r)*to_block_size, -1] - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - # ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - second_last_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, -2], second_last_key_mat) - second_last_seq_pad = jnp.concatenate( - [ - to_mask[:, :, :, :to_block_size], - to_mask[:, :, :, -3 * to_block_size :], - jnp.ones([bsz, 1, 1, n_rand_blocks * to_block_size], dtype=to_mask.dtype), - ], - axis=3, - ) - second_last_rand_pad = jnp.concatenate( - [ - jnp.ones([bsz, n_heads, from_block_size, 4 * to_block_size], dtype=rand_mask.dtype), - rand_mask[:, :, -1], - ], - axis=3, - ) - second_last_product = second_last_product * rsqrt_d - second_last_product += (1.0 - jnp.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty - second_last_attn_weights = jax.nn.softmax( - second_last_product, axis=-1 - ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - - # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - # ==> [bsz, n_heads, from_block_size, -1] - second_last_context_layer = jnp.einsum("bhqk,bhkd->bhqd", second_last_attn_weights, second_last_value_mat) - second_last_context_layer = jnp.expand_dims(second_last_context_layer, 2) - - # 5th PART - # last block (global) attention scores - # q[-1] x (k[0], k[1], k[2], k[3], .... ) - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] - last_product = jnp.einsum("bhqd,bhkd->bhqk", blocked_query_matrix[:, :, -1], key_layer) - last_product = last_product * rsqrt_d - last_product += (1.0 - to_mask) * attn_mask_penalty - last_attn_weights = jax.nn.softmax(last_product, axis=-1) # [bsz, n_heads, from_block_size, n] - - # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] - last_context_layer = jnp.einsum("bhqk,bhkd->bhqd", last_attn_weights, value_layer) - last_context_layer = jnp.expand_dims(last_context_layer, 2) - - # combining representations of all tokens - context_layer = jnp.concatenate( - [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], - axis=2, - ) - context_layer = context_layer.reshape(bsz, n_heads, from_seq_len, -1) * from_mask - context_layer = jnp.transpose(context_layer, axes=(0, 2, 1, 3)).reshape(bsz, from_seq_len, -1) - - attention_probs = None - - return context_layer, attention_probs - - @staticmethod - def jax_gather(params, indices, batch_dims=2): - """ - Gather the indices from params correctly (equivalent to tf.gather but with modifications) - - Args: - params: (bsz, n_heads, num_blocks, block_size, head_dim) - indices: (bhlqk", from_blocked_mask[:, 1:-1], rand_mask) - return rand_mask - - @staticmethod - def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): - """ - Gives the plan of where to put random attention. - - Args: - from_seq_length: int. length of from sequence. - from_block_size: int. size of block in from sequence. - num_rand_blocks: int. Number of random chunks per row. - - Returns: - plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for - each block - """ - - plan_from_length = [] - plan_num_rand_blocks = [] - if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): - plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) - plan_num_rand_blocks.append(num_rand_blocks) - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(0) - elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): - plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) - plan_num_rand_blocks.append(num_rand_blocks // 2) - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) - else: - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(num_rand_blocks) - - return plan_from_length, plan_num_rand_blocks - - @staticmethod - def _bigbird_block_rand_mask( - from_seq_length, - to_seq_length, - from_block_size, - to_block_size, - num_rand_blocks, - indices_prng_key: Optional[jax.random.PRNGKey] = None, - deterministic: Optional[bool] = True, - last_idx: Optional[int] = -1, - ): - """ - Create adjacency list of random attention. - - Args: - from_seq_length: int. length of from sequence. - to_seq_length: int. length of to sequence. - from_block_size: int. size of block in from sequence. - to_block_size: int. size of block in to sequence. - num_rand_blocks: int. Number of random chunks per row. - indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations. - deterministic: bool. When False random attention will be used. - last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, - if positive then num_rand_blocks blocks chosen only up to last_idx. - - Returns: - adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks - """ - # using this method when from_seq_length in [1024, 3072, 4096] - - if from_seq_length // from_block_size != to_seq_length // to_block_size: - raise ValueError("Error the number of blocks needs to be same!") - rand_attn = jnp.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=jnp.int32) - # deterministic nor randomness - if deterministic: - return rand_attn - - middle_seq = jnp.arange(1, to_seq_length // to_block_size - 1, dtype=jnp.int32) - last = to_seq_length // to_block_size - 1 - if last_idx > (2 * to_block_size): - last = (last_idx // to_block_size) - 1 - - r = num_rand_blocks # shorthand - for i in range(1, from_seq_length // from_block_size - 1): - start = i - 2 - end = i - if i == 1: - seq_values = jax.random.permutation(indices_prng_key, middle_seq[2:last])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - elif i == 2: - seq_values = jax.random.permutation(indices_prng_key, middle_seq[3:last])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - elif i == from_seq_length // from_block_size - 3: - seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - # Missing -3: should have been sliced till last-3 - elif i == from_seq_length // from_block_size - 2: - seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - # Missing -4: should have been sliced till last-4 - else: - if start > last: - start = last - seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - elif (end + 1) == last: - seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - else: - concat_values = jnp.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) - seq_values = jax.random.permutation(indices_prng_key, concat_values)[:r] - rand_attn = rand_attn.at[i - 1].set(seq_values) - return rand_attn - - def _bigbird_block_rand_mask_with_head( - self, - from_seq_length, - to_seq_length, - from_block_size, - to_block_size, - num_heads, - plan_from_length, - plan_num_rand_blocks, - indices_prng_key: Optional[jax.random.PRNGKey] = None, - deterministic: Optional[bool] = True, - window_block_left=1, - window_block_right=1, - global_block_top=1, - global_block_bottom=1, - global_block_left=1, - global_block_right=1, - ): - """ - Create adjacency list of random attention. - - Args: - from_seq_length: int. length of from sequence. - to_seq_length: int. length of to sequence. - from_block_size: int. size of block in from sequence. - to_block_size: int. size of block in to sequence. - num_heads: int. total number of heads. - plan_from_length: list. plan from length where num_random_blocks are chosen from. - plan_num_rand_blocks: list. number of rand blocks within the plan. - indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations. - deterministic: bool. When False random attention will be used. - window_block_left: int. number of blocks of window to left of a block. - window_block_right: int. number of blocks of window to right of a block. - global_block_top: int. number of blocks at the top. - global_block_bottom: int. number of blocks at the bottom. - global_block_left: int. Number of blocks globally used to the left. - global_block_right: int. Number of blocks globally used to the right. - - Returns: - adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by - num_rand_blocks - """ - # using this method when from_seq_length not in [1024, 3072, 4096] - - if from_seq_length // from_block_size != to_seq_length // to_block_size: - raise ValueError("Error the number of blocks needs to be same!") - - if from_seq_length not in plan_from_length: - raise ValueError("Error from sequence length not in plan!") - - # Total number of blocks in the mmask - num_blocks = from_seq_length // from_block_size - # Number of blocks per plan - plan_block_length = jnp.array(plan_from_length) // from_block_size - # till when to follow plan - max_plan_idx = plan_from_length.index(from_seq_length) - - # Random Attention adjacency list - rand_attn = [ - jnp.zeros((num_blocks, sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=jnp.int32) - for i in range(num_heads) - ] - - # deterministic - if deterministic: - for nh in range(num_heads): - rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] - return rand_attn - - # We will go iteratively over the plan blocks and pick random number of - # Attention blocks from the legally allowed blocks - for plan_idx in range(max_plan_idx + 1): - rnd_r_cnt = 0 - if plan_idx > 0: - # set the row for all from_blocks starting from 0 to - # plan_block_length[plan_idx-1] - # column indx start from plan_block_length[plan_idx-1] and ends at - # plan_block_length[plan_idx] - if plan_num_rand_blocks[plan_idx] > 0: - rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx])) - curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1])) - for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): - for h in range(num_heads): - single_block_row_attention = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=plan_block_length[plan_idx - 1], - to_end_block_id=plan_block_length[plan_idx], - num_rand_blocks=plan_num_rand_blocks[plan_idx], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - indices_prng_key=indices_prng_key, - ) - rand_attn[h] = ( - rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) - ) - - for pl_id in range(plan_idx): - if plan_num_rand_blocks[pl_id] == 0: - continue - for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): - rnd_r_cnt = 0 - to_start_block_id = 0 - if pl_id > 0: - rnd_r_cnt = int(sum(plan_num_rand_blocks[:pl_id])) - to_start_block_id = plan_block_length[pl_id - 1] - curr_r_cnt = int(sum(plan_num_rand_blocks[: pl_id + 1])) - for h in range(num_heads): - single_block_row_attention = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=to_start_block_id, - to_end_block_id=plan_block_length[pl_id], - num_rand_blocks=plan_num_rand_blocks[pl_id], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - indices_prng_key=indices_prng_key, - ) - rand_attn[h] = ( - rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) - ) - - if plan_num_rand_blocks[plan_idx] == 0: - continue - curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1])) - from_start_block_id = global_block_top - to_start_block_id = 0 - if plan_idx > 0: - rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx])) - from_start_block_id = plan_block_length[plan_idx - 1] - to_start_block_id = plan_block_length[plan_idx - 1] - for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): - for h in range(num_heads): - single_block_row_attention = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=to_start_block_id, - to_end_block_id=plan_block_length[plan_idx], - num_rand_blocks=plan_num_rand_blocks[plan_idx], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - indices_prng_key=indices_prng_key, - ) - rand_attn[h] = rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention) - - for nh in range(num_heads): - rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] - return rand_attn - - @staticmethod - def _get_single_block_row_attention( - block_id, - to_start_block_id, - to_end_block_id, - num_rand_blocks, - indices_prng_key: Optional[jax.random.PRNGKey] = None, - window_block_left=1, - window_block_right=1, - global_block_left=1, - global_block_right=1, - ): - """ - For a single row block get random row attention. - - Args: - block_id: int. block id of row. - to_start_block_id: int. random attention column start id. - to_end_block_id: int. random attention column end id. - num_rand_blocks: int. number of random blocks to be selected. - indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations - window_block_left: int. number of blocks of window to left of a block. - window_block_right: int. number of blocks of window to right of a block. - global_block_left: int. Number of blocks globally used to the left. - global_block_right: int. Number of blocks globally used to the right. - - Returns: - row containing the random attention vector of size num_rand_blocks. - """ - # list of to_blocks from which to choose random attention - to_block_list = jnp.arange(to_start_block_id, to_end_block_id, dtype=jnp.int32) - # permute the blocks - perm_block = jax.random.permutation(indices_prng_key, to_block_list) - - # illegal blocks for the current block id, using window - illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) - - # Add blocks at the start and at the end - illegal_blocks.extend(list(range(global_block_left))) - illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) - - # The second from_block cannot choose random attention on second last to_block - if block_id == 1: - illegal_blocks.append(to_end_block_id - 2) - - # The second last from_block cannot choose random attention on second to_block - if block_id == to_end_block_id - 2: - illegal_blocks.append(1) - - selected_random_blocks = [] - - for i in range(to_end_block_id - to_start_block_id): - if perm_block[i] not in illegal_blocks: - selected_random_blocks.append(perm_block[i]) - if len(selected_random_blocks) == num_rand_blocks: - break - return jnp.array(selected_random_blocks, dtype=jnp.int32) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->BigBird -class FlaxBigBirdSelfOutput(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class FlaxBigBirdAttention(nn.Module): - config: BigBirdConfig - layer_id: Optional[int] = None - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - if self.config.attention_type == "original_full": - self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - elif self.config.attention_type == "block_sparse": - self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype) - else: - raise ValueError( - f"Your `config.attention_type` is {self.config.attention_type} but it can either be `original_full` or" - " `block_sparse`" - ) - - self.output = FlaxBigBirdSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - if self.config.attention_type == "original_full": - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - else: - attn_outputs = self.self( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->BigBird -class FlaxBigBirdIntermediate(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->BigBird -class FlaxBigBirdOutput(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -class FlaxBigBirdLayer(nn.Module): - config: BigBirdConfig - layer_id: Optional[int] = None - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxBigBirdAttention( - self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype - ) - self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype) - self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -class FlaxBigBirdLayerCollection(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->BigBird -class FlaxBigBirdEncoder(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxBigBirdLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->BigBird -class FlaxBigBirdPredictionHeadTransform(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.activation = ACT2FN[self.config.hidden_act] - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return self.LayerNorm(hidden_states) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->BigBird, np.ndarray->jnp.ndarray -class FlaxBigBirdLMPredictionHead(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.transform = FlaxBigBirdPredictionHeadTransform(self.config, dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.transform(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->BigBird -class FlaxBigBirdOnlyMLMHead(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) - return hidden_states - - -class FlaxBigBirdPreTrainingHeads(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype) - self.seq_relationship = nn.Dense(2, dtype=self.dtype) - - def __call__(self, hidden_states, pooled_output, shared_embedding=None): - prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BigBirdConfig - base_model_prefix = "bert" - module_class: nn.Module = None - - def __init__( - self, - config: BigBirdConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - if config.attention_type == "block_sparse" and input_shape is None: - input_shape = (1, 12 * config.block_size) - elif input_shape is None: - input_shape = (1, 1) - - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.zeros_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng, indices_rng = jax.random.split(rng, num=3) - rngs = {"params": params_rng, "dropout": dropout_rng, "indices": indices_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - return_dict=False, - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: Optional[jax.random.PRNGKey] = None, - indices_rng: Optional[jax.random.PRNGKey] = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if indices_rng is not None: - rngs["indices"] = indices_rng - - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxBigBirdAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -class FlaxBigBirdModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBigBirdEncoder( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.pooler = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - - pooled = nn.tanh(self.pooler(hidden_states[:, 0, :])) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.", - BIG_BIRD_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModel with Bert->BigBird -class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdModule - - -append_call_sample_docstring(FlaxBigBirdModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingModule with Bert->BigBird -class FlaxBigBirdForPreTrainingModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - hidden_states = outputs[0] - pooled_output = outputs[1] - - prediction_scores, seq_relationship_score = self.cls( - hidden_states, pooled_output, shared_embedding=shared_embedding - ) - - if not return_dict: - return (prediction_scores, seq_relationship_score) + outputs[2:] - - return FlaxBigBirdForPreTrainingOutput( - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next - sentence prediction (classification)` head. - """, - BIG_BIRD_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForPreTraining with Bert->BigBird -class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForPreTrainingModule - - -FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBigBirdForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") - >>> model = FlaxBigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.prediction_logits - >>> seq_relationship_logits = outputs.seq_relationship_logits - ``` -""" - -overwrite_call_docstring( - FlaxBigBirdForPreTraining, - BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING, -) -append_replace_return_docstrings( - FlaxBigBirdForPreTraining, output_type=FlaxBigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLMModule with Bert->BigBird -class FlaxBigBirdForMaskedLMModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.cls(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLM with Bert->BigBird -class FlaxBigBirdForMaskedLM(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForMaskedLMModule - - -append_call_sample_docstring(FlaxBigBirdForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) - - -class FlaxBigBirdClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__(self, features, deterministic=True): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, deterministic=deterministic) - x = self.dense(x) - x = ACT2FN[self.config.hidden_act](x) - x = self.dropout(x, deterministic=deterministic) - x = self.out_proj(x) - return x - - -class FlaxBigBirdForSequenceClassificationModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[2:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - BIG_BIRD_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForSequenceClassification with Bert->BigBird -class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxBigBirdForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->BigBird -class FlaxBigBirdForMultipleChoiceModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - BIG_BIRD_START_DOCSTRING, -) -class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForMultipleChoiceModule - - def __init__( - self, - config: BigBirdConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if config.attention_type == "block_sparse" and input_shape is None: - input_shape = (1, 1, 12 * config.block_size) - elif input_shape is None: - input_shape = (1, 1) - super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - -overwrite_call_docstring( - FlaxBigBirdForMultipleChoice, BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxBigBirdForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->BigBird -class FlaxBigBirdForTokenClassificationModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - BIG_BIRD_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassification with Bert->BigBird -class FlaxBigBirdForTokenClassification(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForTokenClassificationModule - - -append_call_sample_docstring( - FlaxBigBirdForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBigBirdForQuestionAnsweringHead(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype) - self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__(self, encoder_output, deterministic=True): - hidden_states = self.dropout(encoder_output, deterministic=deterministic) - hidden_states = self.intermediate(hidden_states) - hidden_states = self.output(hidden_states, encoder_output) - hidden_states = self.qa_outputs(hidden_states) - return hidden_states - - -class FlaxBigBirdForQuestionAnsweringModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - add_pooling_layer: bool = False - gradient_checkpointing: bool = False - - def setup(self): - self.config.num_labels = 2 - self.bert = FlaxBigBirdModule( - self.config, - dtype=self.dtype, - add_pooling_layer=self.add_pooling_layer, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - logits_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - pooled_output = outputs[1] if self.add_pooling_layer else None - logits = self.qa_classifier(hidden_states, deterministic=deterministic) - - if logits_mask is not None: - # removing question tokens from the competition - logits = logits - logits_mask * 1e6 - - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxBigBirdForQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - pooled_output=pooled_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - BIG_BIRD_START_DOCSTRING, -) -class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForQuestionAnsweringModule - - @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - question_lengths=None, - params: Optional[dict] = None, - dropout_rng: Optional[jax.random.PRNGKey] = None, - indices_rng: Optional[jax.random.PRNGKey] = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - if question_lengths is None and input_ids is not None: - # assuming input_ids format: context - question_lengths = jnp.argmax((input_ids == self.config.sep_token_id).astype("i4"), axis=-1) + 1 - question_lengths = jnp.expand_dims(question_lengths, axis=1) - - seqlen = input_ids.shape[1] - - logits_mask = None - if question_lengths is not None: - # setting lengths logits to `-inf` - logits_mask = self.prepare_question_mask(question_lengths, seqlen) - if token_type_ids is None: - token_type_ids = (~logits_mask).astype("i4") - logits_mask = jnp.expand_dims(logits_mask, axis=2) - logits_mask = logits_mask.at[:, 0].set(False) - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - if indices_rng is not None: - rngs["indices"] = indices_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids, - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - logits_mask, - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - @staticmethod - def prepare_question_mask(q_lengths, maxlen: int): - # q_lengths -> (bz, 1) - mask = jnp.arange(0, maxlen) - mask = jnp.expand_dims(mask, axis=0) < q_lengths - return mask - - -append_call_sample_docstring( - FlaxBigBirdForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxBigBirdForQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxBigBirdForCausalLMModule(nn.Module): - config: BigBirdConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.bert = FlaxBigBirdModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.cls(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - BigBird Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for - autoregressive tasks. - """, - BIG_BIRD_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->BigBird -class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel): - module_class = FlaxBigBirdForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxBigBirdForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxBigBirdForCausalLM", - "FlaxBigBirdForMaskedLM", - "FlaxBigBirdForMultipleChoice", - "FlaxBigBirdForPreTraining", - "FlaxBigBirdForQuestionAnswering", - "FlaxBigBirdForSequenceClassification", - "FlaxBigBirdForTokenClassification", - "FlaxBigBirdModel", - "FlaxBigBirdPreTrainedModel", -] diff --git a/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py b/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py deleted file mode 100644 index d0a312ebc11f..000000000000 --- a/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py +++ /dev/null @@ -1,169 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import tensorflow as tf -import torch -from tqdm import tqdm - -from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration - - -INIT_COMMON = [ - # tf -> hf - ("/", "."), - ("layer_", "layers."), - ("kernel", "weight"), - ("beta", "bias"), - ("gamma", "weight"), - ("pegasus", "model"), -] -END_COMMON = [ - (".output.dense", ".fc2"), - ("intermediate.LayerNorm", "final_layer_norm"), - ("intermediate.dense", "fc1"), -] - -DECODER_PATTERNS = ( - INIT_COMMON - + [ - ("attention.self.LayerNorm", "self_attn_layer_norm"), - ("attention.output.dense", "self_attn.out_proj"), - ("attention.self", "self_attn"), - ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"), - ("attention.encdec_output.dense", "encoder_attn.out_proj"), - ("attention.encdec", "encoder_attn"), - ("key", "k_proj"), - ("value", "v_proj"), - ("query", "q_proj"), - ("decoder.LayerNorm", "decoder.layernorm_embedding"), - ] - + END_COMMON -) - -REMAINING_PATTERNS = ( - INIT_COMMON - + [ - ("embeddings.word_embeddings", "shared.weight"), - ("embeddings.position_embeddings", "embed_positions.weight"), - ("attention.self.LayerNorm", "self_attn_layer_norm"), - ("attention.output.dense", "self_attn.output"), - ("attention.self", "self_attn.self"), - ("encoder.LayerNorm", "encoder.layernorm_embedding"), - ] - + END_COMMON -) - -KEYS_TO_IGNORE = [ - "encdec/key/bias", - "encdec/query/bias", - "encdec/value/bias", - "self/key/bias", - "self/query/bias", - "self/value/bias", - "encdec_output/dense/bias", - "attention/output/dense/bias", -] - - -def rename_state_dict_key(k, patterns): - for tf_name, hf_name in patterns: - k = k.replace(tf_name, hf_name) - return k - - -def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration: - cfg = BigBirdPegasusConfig(**config_update) - torch_model = BigBirdPegasusForConditionalGeneration(cfg) - state_dict = torch_model.state_dict() - mapping = {} - - # separating decoder weights - decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")} - remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")} - - for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"): - conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] - if any(conditions): - continue - patterns = DECODER_PATTERNS - new_k = rename_state_dict_key(k, patterns) - if new_k not in state_dict: - raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") - if any(i in k for i in ["dense", "query", "key", "value"]): - v = v.T - mapping[new_k] = torch.from_numpy(v) - assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" - - for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"): - conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] - if any(conditions): - continue - patterns = REMAINING_PATTERNS - new_k = rename_state_dict_key(k, patterns) - if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings": - raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") - if any(i in k for i in ["dense", "query", "key", "value"]): - v = v.T - mapping[new_k] = torch.from_numpy(v) - if k != "pegasus/embeddings/position_embeddings": - assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" - - mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"] - mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight") - missing, extra = torch_model.load_state_dict(mapping, strict=False) - unexpected_missing = [ - k - for k in missing - if k - not in [ - "final_logits_bias", - "model.encoder.embed_tokens.weight", - "model.decoder.embed_tokens.weight", - "lm_head.weight", - ] - ] - assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" - assert extra == [], f"no matches found for the following tf keys {extra}" - return torch_model - - -def get_tf_weights_as_numpy(path) -> dict: - init_vars = tf.train.list_variables(path) - tf_weights = {} - ignore_name = ["global_step"] - for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): - skip_key = any(pat in name for pat in ignore_name) - if skip_key: - continue - array = tf.train.load_variable(path, name) - tf_weights[name] = array - return tf_weights - - -def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict): - tf_weights = get_tf_weights_as_numpy(ckpt_path) - torch_model = convert_bigbird_pegasus(tf_weights, config_update) - torch_model.save_pretrained(save_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables") - parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.") - args = parser.parse_args() - config_update = {} - convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update) diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py deleted file mode 100644 index 8b147211881b..000000000000 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ /dev/null @@ -1,1508 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Blenderbot model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_blenderbot import BlenderbotConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "BlenderbotConfig" -_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill" - - -BLENDERBOT_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BlenderbotConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BLENDERBOT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -BLENDERBOT_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -BLENDERBOT_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Blenderbot -class FlaxBlenderbotAttention(nn.Module): - config: BlenderbotConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Blenderbot -class FlaxBlenderbotEncoderLayer(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBlenderbotAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Blenderbot -class FlaxBlenderbotEncoderLayerCollection(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBlenderbotEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Blenderbot -class FlaxBlenderbotDecoderLayer(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBlenderbotAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxBlenderbotAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Blenderbot -class FlaxBlenderbotDecoderLayerCollection(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBlenderbotDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxBlenderbotEncoder(nn.Module): - config: BlenderbotConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - self.embed_positions = nn.Embed( - self.config.max_position_embeddings, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.layers = FlaxBlenderbotEncoderLayerCollection(self.config, self.dtype) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxBlenderbotDecoder(nn.Module): - config: BlenderbotConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - self.embed_positions = nn.Embed( - self.config.max_position_embeddings, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.layers = FlaxBlenderbotDecoderLayerCollection(self.config, self.dtype) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Blenderbot -class FlaxBlenderbotModule(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxBlenderbotDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): - config_class = BlenderbotConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: BlenderbotConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule - input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(BLENDERBOT_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BlenderbotConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration - - >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(BLENDERBOT_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BlenderbotConfig - ) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration - - >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBlenderbotAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare MBart Model transformer outputting raw hidden-states without any specific head on top.", - BLENDERBOT_START_DOCSTRING, -) -class FlaxBlenderbotModel(FlaxBlenderbotPreTrainedModel): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxBlenderbotModule - - -append_call_sample_docstring(FlaxBlenderbotModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Blenderbot -class FlaxBlenderbotForConditionalGenerationModule(nn.Module): - config: BlenderbotConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxBlenderbotModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING -) -class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel): - module_class = FlaxBlenderbotForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(BLENDERBOT_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BlenderbotConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration - - >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBlenderbotAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_BLENDERBOT_CONDITIONAL_GENERATION_DOCSTRING = r""" - Returns: - - Conversation example:: - - ```py - >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration - - >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") - - >>> UTTERANCE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([UTTERANCE], max_length=1024, return_tensors="np") - - >>> # Generate Reply - >>> reply_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5, early_stopping=True).sequences - >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids]) - ``` -""" - -overwrite_call_docstring( - FlaxBlenderbotForConditionalGeneration, - BLENDERBOT_INPUTS_DOCSTRING + FLAX_BLENDERBOT_CONDITIONAL_GENERATION_DOCSTRING, -) -append_replace_return_docstrings( - FlaxBlenderbotForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = ["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"] diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py deleted file mode 100644 index 78f4f6a6761e..000000000000 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ /dev/null @@ -1,1557 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Facebook, Inc and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Blenderbot model.""" - -from __future__ import annotations - -import os -import random -import warnings - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_blenderbot import BlenderbotConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill" -_CONFIG_FOR_DOC = "BlenderbotConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFBlenderbotLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - super().__init__(num_embeddings, embedding_dim, **kwargs) - - def call( - self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(seq_len, delta=1, name="range") - position_ids += past_key_values_length - - return super().call(tf.cast(position_ids, dtype=tf.int32)) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot -class TFBlenderbotAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Blenderbot -class TFBlenderbotEncoderLayer(keras.layers.Layer): - def __init__(self, config: BlenderbotConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBlenderbotAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - layer_head_mask: tf.Tensor, - training: bool | None = False, - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(encoder_attention_heads,)* - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Blenderbot -class TFBlenderbotDecoderLayer(keras.layers.Layer): - def __init__(self, config: BlenderbotConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBlenderbotAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFBlenderbotAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape *(batch, seq_len, embed_dim)* - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(decoder_attention_heads,)* - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - *(decoder_attention_heads,)* - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFBlenderbotPreTrainedModel(TFPreTrainedModel): - config_class = BlenderbotConfig - base_model_prefix = "model" - - -BLENDERBOT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`BlenderbotConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BLENDERBOT_GENERATION_EXAMPLE = r""" - Conversation example:: - - ```py - >>> from transformers import AutoTokenizer, TFBlenderbotForConditionalGeneration - - >>> mname = "facebook/blenderbot-400M-distill" - >>> model = TFBlenderbotForConditionalGeneration.from_pretrained(mname) - >>> tokenizer = AutoTokenizer.from_pretrained(mname) - >>> UTTERANCE = "My friends are cool but they eat too many carbs." - >>> print("Human: ", UTTERANCE) - - >>> inputs = tokenizer([UTTERANCE], return_tensors="tf") - >>> reply_ids = model.generate(**inputs) - >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]) - - >>> REPLY = "I'm not sure" - >>> print("Human: ", REPLY) - >>> NEXT_UTTERANCE = ( - ... "My friends are cool but they eat too many carbs. That's unfortunate. " - ... "Are they trying to lose weight or are they just trying to be healthier? " - ... " I'm not sure." - ... ) - >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="tf") - >>> next_reply_ids = model.generate(**inputs) - >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]) - ``` -""" - -BLENDERBOT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFBlenderbotEncoder(keras.layers.Layer): - config_class = BlenderbotConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFBlenderbotEncoderLayer`]. - - Args: - config: BlenderbotConfig - """ - - def __init__(self, config: BlenderbotConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFBlenderbotLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBlenderbotDecoder(keras.layers.Layer): - config_class = BlenderbotConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBlenderbotDecoderLayer`] - - Args: - config: BlenderbotConfig - embed_tokens: output embedding - """ - - def __init__(self, config: BlenderbotConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFBlenderbotLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFBlenderbotDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = hidden_states + positions - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBlenderbotMainLayer(keras.layers.Layer): - config_class = BlenderbotConfig - - def __init__(self, config: BlenderbotConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" - - self.encoder = TFBlenderbotEncoder(config, self.shared, name="encoder") - self.decoder = TFBlenderbotDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - decoder_position_ids=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - **kwargs, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.", - BLENDERBOT_START_DOCSTRING, -) -class TFBlenderbotModel(TFBlenderbotPreTrainedModel): - def __init__(self, config: BlenderbotConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFBlenderbotMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None, *model_args, **kwargs): - if pretrained_model_name_or_path == "facebook/blenderbot-90M": - from ..blenderbot_small import TFBlenderbotSmallModel - - warnings.warn( - "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical" - " checkpoint `facebook/small_blenderbot-90M` with" - " `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')`" - " instead.", - FutureWarning, - ) - return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path) - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> tuple[tf.Tensor] | TFSeq2SeqModelOutput: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The BLENDERBOT Model with a language modeling head. Can be used for summarization.", - BLENDERBOT_START_DOCSTRING, -) -class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFBlenderbotMainLayer(config, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None, *model_args, **kwargs): - if pretrained_model_name_or_path == "facebook/blenderbot-90M": - from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration - - warnings.warn( - "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical" - " checkpoint `facebook/small_blenderbot-90M` with" - " `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')`" - " instead.", - FutureWarning, - ) - return TFBlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path) - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(BLENDERBOT_GENERATION_EXAMPLE) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor] | TFSeq2SeqLMOutput: - r""" - labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"] diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py deleted file mode 100644 index ac30320bbdb4..000000000000 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ /dev/null @@ -1,1528 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax BlenderbotSmall model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, logging, replace_return_docstrings -from .configuration_blenderbot_small import BlenderbotSmallConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/blenderbot_small-90M" -_CONFIG_FOR_DOC = "BlenderbotSmallConfig" - -BLENDERBOT_SMALL_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BlenderbotSmallConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -BLENDERBOT_SMALL_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->BlenderbotSmall -class FlaxBlenderbotSmallAttention(nn.Module): - config: BlenderbotSmallConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer with Bart->BlenderbotSmall -class FlaxBlenderbotSmallEncoderLayer(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBlenderbotSmallAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->BlenderbotSmall -class FlaxBlenderbotSmallEncoderLayerCollection(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBlenderbotSmallEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer with Bart->BlenderbotSmall -class FlaxBlenderbotSmallDecoderLayer(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxBlenderbotSmallAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxBlenderbotSmallAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->BlenderbotSmall -class FlaxBlenderbotSmallDecoderLayerCollection(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxBlenderbotSmallDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxBlenderbotSmallEncoder(nn.Module): - config: BlenderbotSmallConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - self.embed_positions = nn.Embed( - self.config.max_position_embeddings, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.layers = FlaxBlenderbotSmallEncoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxBlenderbotSmallDecoder(nn.Module): - config: BlenderbotSmallConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - self.embed_positions = nn.Embed( - self.config.max_position_embeddings, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.layers = FlaxBlenderbotSmallDecoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = self.embed_positions(position_ids) - - # BlenderbotSmall applies layer norm on inputs_embeds in decoder - inputs_embeds = self.layernorm_embedding(inputs_embeds) - hidden_states = inputs_embeds + positions - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->BlenderbotSmall -class FlaxBlenderbotSmallModule(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.encoder = FlaxBlenderbotSmallEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxBlenderbotSmallDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel): - config_class = BlenderbotSmallConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: BlenderbotSmallConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule - input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(BLENDERBOT_SMALL_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BlenderbotSmallConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxBlenderbotSmallForConditionalGeneration - - >>> model = FlaxBlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BlenderbotSmallConfig - ) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBlenderbotSmallForConditionalGeneration - - >>> model = FlaxBlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBlenderbotSmallAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare BlenderbotSmall Model transformer outputting raw hidden-states without any specific head on top.", - BLENDERBOT_SMALL_START_DOCSTRING, -) -class FlaxBlenderbotSmallModel(FlaxBlenderbotSmallPreTrainedModel): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxBlenderbotSmallModule - - -append_call_sample_docstring(FlaxBlenderbotSmallModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->BlenderbotSmall -class FlaxBlenderbotSmallForConditionalGenerationModule(nn.Module): - config: BlenderbotSmallConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxBlenderbotSmallModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.", - BLENDERBOT_SMALL_START_DOCSTRING, -) -class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedModel): - module_class = FlaxBlenderbotSmallForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BlenderbotSmallConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - deterministic: bool = True, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxBlenderbotSmallForConditionalGeneration - - >>> model = FlaxBlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBlenderbotSmallAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias.astype(self.dtype) - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_BLENDERBOT_SMALL_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Summarization example: - - ```py - >>> from transformers import AutoTokenizer, FlaxBlenderbotSmallForConditionalGeneration - - >>> model = FlaxBlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M") - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"]).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` - - Mask filling example: - - ```py - >>> from transformers import AutoTokenizer, FlaxBlenderbotSmallForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M") - >>> TXT = "My friends are but they eat too many carbs." - - >>> model = FlaxBlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M") - >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = jax.lax.top_k(probs) - - >>> tokenizer.decode(predictions).split() - ``` -""" - -overwrite_call_docstring( - FlaxBlenderbotSmallForConditionalGeneration, - BLENDERBOT_SMALL_INPUTS_DOCSTRING + FLAX_BLENDERBOT_SMALL_CONDITIONAL_GENERATION_DOCSTRING, -) -append_replace_return_docstrings( - FlaxBlenderbotSmallForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = [ - "FlaxBlenderbotSmallForConditionalGeneration", - "FlaxBlenderbotSmallModel", - "FlaxBlenderbotSmallPreTrainedModel", -] diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py deleted file mode 100644 index be7711801ed2..000000000000 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ /dev/null @@ -1,1527 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Facebook, Inc and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 BlenderbotSmall model.""" - -from __future__ import annotations - -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_blenderbot_small import BlenderbotSmallConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/blenderbot_small-90M" -_CONFIG_FOR_DOC = "BlenderbotSmallConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# Copied from transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall -class TFBlenderbotSmallLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - super().__init__(num_embeddings, embedding_dim, **kwargs) - - def call( - self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(seq_len, delta=1, name="range") - position_ids += past_key_values_length - - return super().call(tf.cast(position_ids, dtype=tf.int32)) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->BlenderbotSmall -class TFBlenderbotSmallAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartEncoderLayer with Bart->BlenderbotSmall -class TFBlenderbotSmallEncoderLayer(keras.layers.Layer): - def __init__(self, config: BlenderbotSmallConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBlenderbotSmallAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None, - layer_head_mask: tf.Tensor | None, - training: bool | None = False, - ) -> tf.Tensor: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)` - """ - residual = hidden_states - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartDecoderLayer with Bart->BlenderbotSmall -class TFBlenderbotSmallDecoderLayer(keras.layers.Layer): - def __init__(self, config: BlenderbotSmallConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFBlenderbotSmallAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFBlenderbotSmallAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel): - config_class = BlenderbotSmallConfig - base_model_prefix = "model" - - -BLENDERBOT_SMALL_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`BlenderbotSmallConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BLENDERBOT_SMALL_GENERATION_EXAMPLE = r""" - Conversation example:: - - ```py - >>> from transformers import AutoTokenizer, TFBlenderbotSmallForConditionalGeneration - - >>> mname = "facebook/blenderbot_small-90M" - >>> model = BlenderbotSmallForConditionalGeneration.from_pretrained(mname) - >>> tokenizer = AutoTokenizer.from_pretrained(mname) - - >>> UTTERANCE = "My friends are cool but they eat too many carbs." - >>> print("Human: ", UTTERANCE) - >>> inputs = tokenizer([UTTERANCE], return_tensors="tf") - - >>> reply_ids = model.generate(**inputs) - >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]) - what kind of carbs do they eat? i don't know much about carbs. - - >>> REPLY = "I'm not sure" - >>> print("Human: ", REPLY) - >>> NEXT_UTTERANCE = ( - ... "My friends are cool but they eat too many carbs. " - ... "what kind of carbs do they eat? i don't know much about carbs. " - ... "I'm not sure." - ... ) - - >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="tf") - >>> inputs.pop("token_type_ids") - >>> next_reply_ids = model.generate(**inputs) - >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]) - ``` -""" - -BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - BlenderbotSmall uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFBlenderbotSmallEncoder(keras.layers.Layer): - config_class = BlenderbotSmallConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFBlenderbotSmallEncoderLayer`]. - - Args: - config: BlenderbotSmallConfig - """ - - def __init__(self, config: BlenderbotSmallConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.embed_dim = config.d_model - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.embed_dim]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBlenderbotSmallDecoder(keras.layers.Layer): - config_class = BlenderbotSmallConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBlenderbotSmallDecoderLayer`] - - Args: - config: BlenderbotSmallConfig - embed_tokens: output embedding - """ - - def __init__(self, config: BlenderbotSmallConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFBlenderbotSmallDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - hidden_states = self.layernorm_embedding(inputs_embeds) + positions - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFBlenderbotSmallMainLayer(keras.layers.Layer): - config_class = BlenderbotSmallConfig - - def __init__(self, config: BlenderbotSmallConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" - - self.encoder = TFBlenderbotSmallEncoder(config, self.shared, name="encoder") - self.decoder = TFBlenderbotSmallDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - decoder_position_ids=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - **kwargs, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.", - BLENDERBOT_SMALL_START_DOCSTRING, -) -class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): - def __init__(self, config: BlenderbotSmallConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFBlenderbotSmallMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> tuple[tf.Tensor] | TFSeq2SeqModelOutput: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.", - BLENDERBOT_SMALL_START_DOCSTRING, -) -class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFBlenderbotSmallMainLayer(config, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(BLENDERBOT_SMALL_GENERATION_EXAMPLE) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor] | TFSeq2SeqLMOutput: - r""" - labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"] diff --git a/src/transformers/models/blip/modeling_tf_blip.py b/src/transformers/models/blip/modeling_tf_blip.py deleted file mode 100644 index a1a1f7928273..000000000000 --- a/src/transformers/models/blip/modeling_tf_blip.py +++ /dev/null @@ -1,1709 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow BLIP model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass -from typing import Any - -import tensorflow as tf - -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling -from ...modeling_tf_utils import ( - TFPreTrainedModel, - get_initializer, - get_tf_activation, - keras, - keras_serializable, - shape_list, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_tf_blip_text import BLIP_TEXT_INPUTS_DOCSTRING, TFBlipTextLMHeadModel, TFBlipTextModel - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base" - - -# Copied from transformers.models.clip.modeling_tf_clip.contrastive_loss -def contrastive_loss(logits: tf.Tensor) -> tf.Tensor: - return tf.math.reduce_mean( - keras.metrics.sparse_categorical_crossentropy( - y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True - ) - ) - - -# Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->blip -def blip_loss(similarity: tf.Tensor) -> tf.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(tf.transpose(similarity)) - return (caption_loss + image_loss) / 2.0 - - -@dataclass -class TFBlipForConditionalGenerationModelOutput(ModelOutput): - """ - Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the - last hidden states. This class also adds the loss term from the text decoder. - - Args: - loss (`tf.Tensor`, *optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`): - Language modeling loss from the text decoder. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): - Prediction scores of the language modeling head of the text decoder model. - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*): - The image embeddings obtained after applying the Vision Transformer model to the input image. - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads.` - """ - - loss: tuple[tf.Tensor] | None = None - logits: tuple[tf.Tensor] | None = None - image_embeds: tf.Tensor | None = None - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - @property - def decoder_logits(self): - warnings.warn( - "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers." - " Please use the `logits` attribute to retrieve the final output instead.", - FutureWarning, - ) - return self.logits - - -@dataclass -class TFBlipTextVisionModelOutput(ModelOutput): - """ - Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the - last hidden states. This class also adds the loss term from the text decoder. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss from the text decoder. - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - image_embeds: tf.Tensor | None = None - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFBlipImageTextMatchingModelOutput(ModelOutput): - """ - Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the - last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity - scores. - - Args: - itm_score (`tf.Tensor`): - The image-text similarity scores. - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss from the text decoder. - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - vision_pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*): - Last layer hidden-state of the vision of the vision-only branch of the model. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - question_embeds (`tf.Tensor`): - The question embeddings obtained by the text projection layer. - """ - - itm_score: tf.Tensor | None = None - loss: tf.Tensor | None = None - image_embeds: tf.Tensor | None = None - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - vision_pooler_output: tf.Tensor | None = None - attentions: tuple[tf.Tensor, ...] | None = None - question_embeds: tuple[tf.Tensor] | None = None - - -@dataclass -class TFBlipOutput(ModelOutput): - """ - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image:(`tf.Tensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`tf.Tensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds(`tf.Tensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`]. - image_embeds(`tf.Tensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`]. - text_model_output(`BaseModelOutputWithPooling`): - The output of the [`BlipTextModel`]. - vision_model_output(`BaseModelOutputWithPooling`): - The output of the [`BlipVisionModel`]. - """ - - loss: tf.Tensor | None = None - logits_per_image: tf.Tensor | None = None - logits_per_text: tf.Tensor | None = None - text_embeds: tf.Tensor | None = None - image_embeds: tf.Tensor | None = None - text_model_output: TFBaseModelOutputWithPooling = None - vision_model_output: TFBaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class TFBlipVisionEmbeddings(keras.layers.Layer): - def __init__(self, config: BlipVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = keras.layers.Conv2D( - filters=self.embed_dim, - kernel_size=self.patch_size, - strides=self.patch_size, - kernel_initializer=get_initializer(self.config.initializer_range), - data_format="channels_last", - name="patch_embedding", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 - - def build(self, input_shape=None): - self.class_embedding = self.add_weight( - shape=(1, 1, self.embed_dim), - initializer=get_initializer(self.config.initializer_range), - trainable=True, - name="class_embedding", - ) - - self.position_embedding = self.add_weight( - shape=(1, self.num_positions, self.embed_dim), - initializer=get_initializer(self.config.initializer_range), - trainable=True, - name="position_embedding", - ) - - if self.built: - return - self.built = True - if getattr(self, "patch_embedding", None) is not None: - with tf.name_scope(self.patch_embedding.name): - self.patch_embedding.build([None, None, None, 3]) - - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: - # Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch - # likes channels-first convs. - batch_size = tf.shape(pixel_values)[0] - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - patch_embeds = self.patch_embedding(pixel_values) - patch_embeds = tf.reshape(patch_embeds, (batch_size, self.num_patches, -1)) - - class_embeds = tf.broadcast_to(self.class_embedding, (batch_size, 1, self.embed_dim)) - embeddings = tf.concat([class_embeds, patch_embeds], axis=1) - embeddings = embeddings + self.position_embedding[:, : tf.shape(embeddings)[1], :] - return embeddings - - -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->Blip -class TFBlipTextEmbeddings(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - - self.config = config - - def build(self, input_shape: tf.TensorShape = None): - with tf.name_scope("token_embedding"): - self.weight = self.add_weight( - shape=(self.config.vocab_size, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="weight", - ) - - with tf.name_scope("position_embedding"): - self.position_embedding = self.add_weight( - shape=(self.config.max_position_embeddings, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="embeddings", - ) - - super().build(input_shape) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - position_embeds = tf.gather(params=self.position_embedding, indices=position_ids) - position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) - final_embeddings = inputs_embeds + position_embeds - - return final_embeddings - - -class TFBlipAttention(keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = keras.layers.Dropout(config.attention_dropout, name="dropout") - - self.qkv = keras.layers.Dense( - 3 * self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="qkv" - ) - - self.projection = keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="projection" - ) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool | None = None, - ) -> tuple[tf.Tensor, tf.Tensor | None, tuple[tf.Tensor] | None]: - """Input shape: Batch x Time x Channel""" - - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - mixed_qkv = self.qkv(hidden_states) - mixed_qkv = tf.reshape(mixed_qkv, (bsz, tgt_len, 3, self.num_heads, self.head_dim)) - mixed_qkv = tf.transpose(mixed_qkv, perm=(2, 0, 3, 1, 4)) - - query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = query_states @ tf.transpose(key_states, (0, 1, 3, 2)) - - attention_scores = attention_scores * self.scale - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = tf.transpose(attention_probs @ value_states, perm=(0, 2, 1, 3)) - - new_context_layer_shape = shape_list(context_layer)[:-2] + [self.embed_dim] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - - output = self.projection(context_layer) - - outputs = (output, attention_probs) if output_attentions else (output, None) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "qkv", None) is not None: - with tf.name_scope(self.qkv.name): - self.qkv.build([None, None, self.embed_dim]) - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, self.embed_dim]) - - -class TFBlipMLP(keras.layers.Layer): - def __init__(self, config: BlipConfig, **kwargs): - super().__init__(**kwargs) - - self.activation_fn = get_tf_activation(config.hidden_act) - - in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) - fc_std = (2 * config.hidden_size) ** -0.5 - - self.fc1 = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(fc_std), name="fc1" - ) - self.fc2 = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2" - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.fc1(inputs=hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(inputs=hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.config.hidden_size]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.intermediate_size]) - - -class TFBlipEncoderLayer(keras.layers.Layer): - def __init__(self, config: BlipConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.hidden_size - self.self_attn = TFBlipAttention(config, name="self_attn") - self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") - self.mlp = TFBlipMLP(config, name="mlp") - self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - output_attentions: bool | None = False, - training: bool | None = None, - ) -> tuple[tf.Tensor]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - head_mask=attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = hidden_states + residual - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - - hidden_states = hidden_states + residual - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, self.embed_dim]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, self.embed_dim]) - - -class TFBlipPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BlipConfig - base_model_prefix = "blip" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - -BLIP_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`BlipConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -BLIP_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@keras_serializable -class TFBlipEncoder(keras.layers.Layer): - config_class = BlipConfig - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`BlipEncoderLayer`]. - - Args: - config (`BlipConfig`): - The corresponding vision configuration for the `BlipEncoder`. - """ - - def __init__(self, config: BlipConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layers = [TFBlipEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)] - - @unpack_inputs - def call( - self, - inputs_embeds, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBaseModelOutput: - r""" - Args: - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Embedded representation of the inputs. Should be float, not int tokens. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - training=training, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFBlipVisionModel(TFBlipPreTrainedModel): - main_input_name = "pixel_values" - config_class = BlipVisionConfig - - def __init__(self, config: BlipVisionConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.config = config - - self.embeddings = TFBlipVisionEmbeddings(config, name="embeddings") - self.encoder = TFBlipEncoder(config, name="encoder") - self.post_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") - self.embed_dim = config.hidden_size - - def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: - hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None - attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - - return TFBaseModelOutputWithPooling( - last_hidden_state=output.last_hidden_state, - pooler_output=output.pooler_output, - hidden_states=hs, - attentions=attns, - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=BlipVisionConfig) - def call( - self, - pixel_values: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBaseModelOutputWithPooling: - r""" - Returns: - - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooled_output = last_hidden_state[:, 0, :] - # TF gets confused if we call the layer with inputs of different ranks, so insert a singleton dimension - pooled_output = self.post_layernorm(tf.expand_dims(pooled_output, 1)) - pooled_output = tf.squeeze(pooled_output, 1) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def get_input_embeddings(self): - return self.embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "post_layernorm", None) is not None: - with tf.name_scope(self.post_layernorm.name): - self.post_layernorm.build([None, None, self.embed_dim]) - - -class TFBlipMainLayer(keras.layers.Layer): - config_class = BlipConfig - - def __init__(self, config: BlipConfig, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not isinstance(config.text_config, BlipTextConfig): - raise TypeError( - "config.text_config is expected to be of type BlipTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, BlipVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type BlipVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - text_config = config.text_config - vision_config = config.vision_config - - self.projection_dim = config.projection_dim - self.text_embed_dim = text_config.hidden_size - self.vision_embed_dim = vision_config.hidden_size - - self.text_model = TFBlipTextModel(text_config, name="text_model") - self.vision_model = TFBlipVisionModel(vision_config, name="vision_model") - - self.visual_projection = keras.layers.Dense( - self.projection_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="visual_projection", - ) - self.text_projection = keras.layers.Dense( - self.projection_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="text_projection", - ) - - self.config = config - - def build(self, input_shape=None): - self.logit_scale = self.add_weight( - name="logit_scale", - shape=[], - initializer=keras.initializers.Constant(self.config.logit_scale_init_value), - trainable=True, - ) - - if self.built: - return - self.built = True - if getattr(self, "text_model", None) is not None: - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "visual_projection", None) is not None: - with tf.name_scope(self.visual_projection.name): - self.visual_projection.build([None, None, self.vision_embed_dim]) - if getattr(self, "text_projection", None) is not None: - with tf.name_scope(self.text_projection.name): - self.text_projection.build([None, None, self.text_embed_dim]) - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBlipOutput: - # Use BLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[1] - image_embeds = self.visual_projection(image_embeds) - - text_embeds = text_outputs[1] - text_embeds = self.text_projection(text_embeds) - - # normalized features - image_embeds = image_embeds / tf.norm(image_embeds, ord=2, axis=-1, keepdims=True) - text_embeds = text_embeds / tf.norm(text_embeds, ord=2, axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = tf.exp(self.logit_scale) - logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale - logits_per_image = tf.transpose(logits_per_text) - - loss = None - if return_loss: - loss = blip_loss(logits_per_text) - loss = tf.reshape(loss, (1,)) - - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - - return TFBlipOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -class TFBlipModel(TFBlipPreTrainedModel): - config_class = BlipConfig - _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] - main_input_name = "input_ids" - - def __init__(self, config: BlipConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.blip = TFBlipMainLayer(config, name="blip") - - def serving_output(self, output: TFBlipOutput) -> TFBlipOutput: - return TFBlipOutput( - logits_per_image=output.logits_per_image, - logits_per_text=output.logits_per_text, - text_embeds=output.text_embeds, - image_embeds=output.image_embeds, - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBlipOutput, config_class=BlipConfig) - def call( - self, - input_ids: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBlipOutput: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipModel - - >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True - ... ) - - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = tf.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ```""" - outputs = self.blip( - input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids, - return_loss=return_loss, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) - def get_text_features( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - return_dict: bool | None = None, - ) -> tf.Tensor: - r""" - Returns: - text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of [`TFBlipTextModel`]. - - Examples: - - ```python - >>> from transformers import AutoProcessor, TFBlipModel - - >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") - - >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") - >>> text_features = model.get_text_features(**inputs) - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - text_outputs = self.blip.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=return_dict, - ) - - pooled_output = text_outputs[1] - text_features = self.blip.text_projection(pooled_output) - - return text_features - - @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) - def get_image_features( - self, - pixel_values: tf.Tensor | None = None, - return_dict: bool | None = None, - ) -> tf.Tensor: - r""" - Returns: - image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying - the projection layer to the pooled output of [`TFBlipVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipModel - - >>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> image_features = model.get_image_features(**inputs) - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict) - - pooled_output = vision_outputs[1] # pooled_output - image_features = self.blip.visual_projection(pooled_output) - - return image_features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "blip", None) is not None: - with tf.name_scope(self.blip.name): - self.blip.build(None) - - -@add_start_docstrings( - """ - BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass - `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise, - the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption - from the text input. If no text input is provided, the decoder will start with the [BOS] token only. - """, - BLIP_START_DOCSTRING, -) -class TFBlipForConditionalGeneration(TFBlipPreTrainedModel): - config_class = BlipConfig - _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] - main_input_name = "pixel_values" - - def __init__(self, config: BlipConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - - self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") - - self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder") - - self.decoder_input_ids = config.text_config.bos_token_id - self.decoder_pad_token_id = config.text_config.pad_token_id - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.vision_model.embeddings.patch_embedding - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBlipForConditionalGenerationModelOutput, config_class=BlipConfig) - def call( - self, - pixel_values: tf.Tensor, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - labels: tf.Tensor | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBlipForConditionalGenerationModelOutput: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipForConditionalGeneration - - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") - >>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = "A picture of" - - >>> inputs = processor(images=image, text=text, return_tensors="tf") - - >>> outputs = model(**inputs) - ```""" - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[0] - - outputs = self.text_decoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=image_embeds, - labels=labels, - return_dict=False, - training=training, - ) - - if not return_dict: - outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:] - return tuple(output for output in outputs if output is not None) - - if labels is not None: - loss = outputs[0] - logits = outputs[1] - else: - loss = None - logits = outputs[0] - - if loss is not None and loss.shape.rank == 0: - loss = tf.reshape(loss, (1,)) - - return TFBlipForConditionalGenerationModelOutput( - loss=loss, - logits=logits, - image_embeds=image_embeds, - last_hidden_state=vision_outputs.last_hidden_state, - hidden_states=vision_outputs.hidden_states, - attentions=vision_outputs.attentions, - ) - - def generate( - self, - pixel_values: tf.Tensor, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - **generate_kwargs, - ) -> tf.Tensor: - r""" - Overrides *generate* function to be able to use the model as a conditional generator - - Parameters: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`: - Input image to be processed - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - The sequence used as a prompt for the generation. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - Examples: - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipForConditionalGeneration - - >>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> outputs = model.generate(**inputs) - >>> print(processor.decode(outputs[0], skip_special_tokens=True)) - two cats sleeping on a couch - ``` - """ - - batch_size = pixel_values.shape[0] - vision_outputs = self.vision_model(pixel_values=pixel_values) - - image_embeds = vision_outputs[0] - - image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32) - - if isinstance(input_ids, list): - input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32) - elif input_ids is None: - input_ids = tf.convert_to_tensor( - [[self.decoder_input_ids, self.config.text_config.eos_token_id]], dtype=tf.int32 - ) - - input_ids = tf.tile(input_ids, (batch_size, 1)) - - # PyTorch: input_ids[:, 0] = self.config.text_config.bos_token_id - input_ids = tf.concat( - [tf.ones((batch_size, 1), dtype=tf.int32) * self.config.text_config.bos_token_id, input_ids[:, 1:]], axis=1 - ) - attention_mask = attention_mask[:, :-1] if attention_mask is not None else None - - outputs = self.text_decoder.generate( - input_ids=input_ids[:, :-1], - eos_token_id=self.config.text_config.sep_token_id, - pad_token_id=self.config.text_config.pad_token_id, - attention_mask=attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - **generate_kwargs, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "text_decoder", None) is not None: - with tf.name_scope(self.text_decoder.name): - self.text_decoder.build(None) - - -@add_start_docstrings( - """ - BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text - decoder. The vision encoder will encode the input image, the text encoder will encode the input question together - with the encoding of the image, and the text decoder will output the answer to the question. - """, - BLIP_START_DOCSTRING, -) -class TFBlipForQuestionAnswering(TFBlipPreTrainedModel): - config_class = BlipConfig - _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] - - def __init__(self, config: BlipConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - - self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") - - self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False) - - self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder") - - self.decoder_pad_token_id = config.text_config.pad_token_id - self.decoder_start_token_id = config.text_config.bos_token_id - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.vision_model.embeddings.patch_embedding - - # Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right - def _shift_right(self, input_ids): - decoder_start_token_id = self.decoder_start_token_id - pad_token_id = self.decoder_pad_token_id - - if decoder_start_token_id is None or pad_token_id is None: - raise ValueError("decoder_start_token_id and pad_token_id must be defined!") - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) - start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)) - - return shifted_input_ids - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig) - def call( - self, - input_ids: tf.Tensor, - pixel_values: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - labels: tf.Tensor | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBlipTextVisionModelOutput: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipForQuestionAnswering - - >>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> # training - >>> text = "How many cats are in the picture?" - >>> label = "2" - >>> inputs = processor(images=image, text=text, return_tensors="tf") - >>> labels = processor(text=label, return_tensors="tf").input_ids - - >>> inputs["labels"] = labels - >>> outputs = model(**inputs) - >>> loss = outputs.loss - - >>> # inference - >>> text = "How many cats are in the picture?" - >>> inputs = processor(images=image, text=text, return_tensors="tf") - >>> outputs = model.generate(**inputs) - >>> print(processor.decode(outputs[0], skip_special_tokens=True)) - 2 - ```""" - if labels is None and decoder_input_ids is None: - raise ValueError( - "Either `decoder_input_ids` or `labels` should be passed when calling" - " `TFBlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you" - " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`" - ) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[0] - image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64) - - question_embeds = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - return_dict=return_dict, - training=training, - ) - - question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state - - if labels is not None and decoder_input_ids is None: - # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153 - decoder_input_ids = labels - - answer_output = self.text_decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=question_embeds, - encoder_attention_mask=attention_mask, - labels=labels, - return_dict=return_dict, - training=training, - ) - - if labels is not None: - decoder_loss = tf.reduce_mean(answer_output.loss) if return_dict else tf.reduce_mean(answer_output[0]) - else: - decoder_loss = None - - if not return_dict: - outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:] - return tuple(output for output in outputs if output is not None) - - return TFBlipTextVisionModelOutput( - loss=decoder_loss, - image_embeds=image_embeds, - last_hidden_state=vision_outputs.last_hidden_state, - hidden_states=vision_outputs.hidden_states, - attentions=vision_outputs.attentions, - ) - - def generate( - self, - input_ids: tf.Tensor, - pixel_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - **generate_kwargs, - ) -> tf.Tensor: - r""" - Overrides *generate* function to be able to use the model as a conditional generator - - Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`: - Input image to be processed - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for - tokens that are NOT MASKED, `0` for MASKED tokens. - generate_kwargs (dict, *optional*): - Additional arguments passed to the `generate` function of the decoder - - - Examples: - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipForQuestionAnswering - - >>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = "How many cats are in the picture?" - - >>> inputs = processor(images=image, text=text, return_tensors="tf") - - >>> outputs = model.generate(**inputs) - >>> print(processor.decode(outputs[0], skip_special_tokens=True)) - 2 - ``` - """ - vision_outputs = self.vision_model(pixel_values=pixel_values) - - image_embeds = vision_outputs[0] - - image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32) - - if isinstance(input_ids, list): - input_ids = tf.Tensor(input_ids) - - question_outputs = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - return_dict=False, - ) - - question_embeds = question_outputs[0] - - question_attention_mask = tf.ones(shape_list(question_embeds)[:-1], dtype=tf.int32) - - bos_ids = tf.fill( - (tf.shape(question_embeds)[0], 1), value=tf.cast(self.decoder_start_token_id, input_ids.dtype) - ) - - outputs = self.text_decoder.generate( - input_ids=bos_ids, - eos_token_id=self.config.text_config.sep_token_id, - pad_token_id=self.config.text_config.pad_token_id, - encoder_hidden_states=question_embeds, - encoder_attention_mask=question_attention_mask, - **generate_kwargs, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "text_encoder", None) is not None: - with tf.name_scope(self.text_encoder.name): - self.text_encoder.build(None) - if getattr(self, "text_decoder", None) is not None: - with tf.name_scope(self.text_decoder.name): - self.text_decoder.build(None) - - -@add_start_docstrings( - """ - BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of - image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to - the image. - """, - BLIP_START_DOCSTRING, -) -class TFBlipForImageTextRetrieval(TFBlipPreTrainedModel): - config_class = BlipConfig - - def __init__(self, config: BlipConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - - self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model") - - self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False) - - # vision projection layer - self.vision_proj = keras.layers.Dense( - config.image_text_hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="vision_proj", - ) - - # text projection layer - self.text_proj = keras.layers.Dense( - config.image_text_hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="text_proj", - ) - - # image text matching head - self.itm_head = keras.layers.Dense( - 2, kernel_initializer=get_initializer(config.initializer_range), name="itm_head" - ) - - self.decoder_pad_token_id = ( - config.text_config.pad_token_id - if not hasattr(config, "decoder_pad_token_id") - else config.decoder_pad_token_id - ) - self.decoder_start_token_id = ( - config.text_config.bos_token_id - if not hasattr(config, "decoder_start_token_id") - else config.decoder_start_token_id - ) - self.config = config - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.vision_model.embeddings.patch_embedding - - @unpack_inputs - @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBlipImageTextMatchingModelOutput, config_class=BlipVisionConfig) - def call( - self, - input_ids: tf.Tensor, - pixel_values: tf.Tensor | None = None, - use_itm_head: bool | None = True, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = None, - ) -> tuple | TFBlipImageTextMatchingModelOutput: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFBlipForImageTextRetrieval - - >>> model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco") - >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = "an image of a cat" - - >>> inputs = processor(images=image, text=text, return_tensors="tf") - >>> outputs = model(**inputs) - ``` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[0] - image_atts = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64) - - # Matt: In PyTorch, only one path (itm/non-itm) is taken. However, in TensorFlow this can result in - # some layers not being built! To avoid this, we always call both paths, then use an if statement to select - # which output to pass to the final output. The unnecessary nodes will be pruned from the final graph, but - # not before the layers have all been built correctly. - itm_question_embeds = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=return_dict, - training=training, - ) - itm_question_embeds = itm_question_embeds[0] if not return_dict else itm_question_embeds.last_hidden_state - - itm_output = self.itm_head(itm_question_embeds[:, 0, :]) - - no_itm_question_embeds = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - return_dict=return_dict, - training=training, - ) - no_itm_question_embeds = ( - no_itm_question_embeds[0] if not return_dict else no_itm_question_embeds.last_hidden_state - ) - - image_feat, _ = tf.linalg.normalize(self.vision_proj(image_embeds[:, 0, :]), ord=2, axis=-1) - text_feat, _ = tf.linalg.normalize(self.text_proj(no_itm_question_embeds[:, 0, :]), ord=2, axis=-1) - - no_itm_output = tf.matmul(image_feat, text_feat, transpose_b=True) - - if use_itm_head: - output = itm_output - question_embeds = itm_question_embeds - else: - output = no_itm_output - question_embeds = no_itm_question_embeds - - if not return_dict: - outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) - return tuple(output for output in outputs if output is not None) - - return TFBlipImageTextMatchingModelOutput( - itm_score=output, - last_hidden_state=vision_outputs.last_hidden_state, - hidden_states=vision_outputs.hidden_states, - attentions=vision_outputs.attentions, - question_embeds=question_embeds, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "text_encoder", None) is not None: - with tf.name_scope(self.text_encoder.name): - self.text_encoder.build(None) - if getattr(self, "vision_proj", None) is not None: - with tf.name_scope(self.vision_proj.name): - self.vision_proj.build([None, None, self.config.vision_config.hidden_size]) - if getattr(self, "text_proj", None) is not None: - with tf.name_scope(self.text_proj.name): - self.text_proj.build([None, None, self.config.text_config.hidden_size]) - if getattr(self, "itm_head", None) is not None: - with tf.name_scope(self.itm_head.name): - self.itm_head.build([None, None, self.config.text_config.hidden_size]) - - -__all__ = [ - "TFBlipModel", - "TFBlipPreTrainedModel", - "TFBlipForConditionalGeneration", - "TFBlipForQuestionAnswering", - "TFBlipVisionModel", - "TFBlipTextModel", - "TFBlipForImageTextRetrieval", -] diff --git a/src/transformers/models/blip/modeling_tf_blip_text.py b/src/transformers/models/blip/modeling_tf_blip_text.py deleted file mode 100644 index 7dae1126e03b..000000000000 --- a/src/transformers/models/blip/modeling_tf_blip_text.py +++ /dev/null @@ -1,1122 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the BSD-3-clause license (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import math - -import tensorflow as tf - -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, -) -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - get_tf_activation, - keras, - keras_serializable, - shape_list, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, invert_attention_mask, stable_softmax -from ...utils import add_start_docstrings_to_model_forward, logging -from .configuration_blip import BlipTextConfig - - -logger = logging.get_logger(__name__) - -BLIP_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L52 -class TFBlipTextEmbeddings(keras.layers.Layer): - """Construct the embeddings from word and position embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.word_embeddings = keras.layers.Embedding( - config.vocab_size, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="word_embeddings", - ) - self.position_embeddings = keras.layers.Embedding( - config.max_position_embeddings, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="position_embeddings", - ) - - # self.LayerNorm is not snake-cased to stick with PyTorch model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - - self.position_ids = tf.expand_dims(tf.range(config.max_position_embeddings), 0) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - - self.config = config - - def call(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0, training=None): - if input_ids is not None: - input_shape = tf.shape(input_ids) - else: - input_shape = tf.shape(inputs_embeds)[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.word_embeddings(input_ids) - - embeddings = inputs_embeds - - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings, training=training) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "word_embeddings", None) is not None: - with tf.name_scope(self.word_embeddings.name): - self.word_embeddings.build(None) - if getattr(self, "position_embeddings", None) is not None: - with tf.name_scope(self.position_embeddings.name): - self.position_embeddings.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97 -class TFBlipTextSelfAttention(keras.layers.Layer): - def __init__(self, config, is_cross_attention, **kwargs): - super().__init__(**kwargs) - self.config = config - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention heads (%d)" - % (config.hidden_size, config.num_attention_heads) - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = keras.layers.Embedding( - 2 * config.max_position_embeddings - 1, self.attention_head_size - ) - self.is_cross_attention = is_cross_attention - - def transpose_for_scores(self, x): - new_x_shape = tf.concat( - [tf.shape(x)[:-1], tf.constant([self.num_attention_heads, self.attention_head_size], dtype=tf.int32)], - axis=0, - ) - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, perm=(0, 2, 1, 3)) - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - training=None, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = shape_list(hidden_states)[1] - position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64, device=hidden_states.device), 1) - position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64, device=hidden_states.device), 0) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BlipTextModel forward() function) - attention_scores = attention_scores + tf.cast(attention_mask, attention_scores.dtype) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs_dropped = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs_dropped = attention_probs_dropped * head_mask - - context_layer = attention_probs_dropped @ value_layer - - context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) - new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if self.is_cross_attention: - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.encoder_hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.encoder_hidden_size]) - else: - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -class TFBlipTextSelfOutput(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool | None = None) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242 -class TFBlipTextAttention(keras.layers.Layer): - def __init__(self, config, is_cross_attention=False, **kwargs): - super().__init__(**kwargs) - self.self = TFBlipTextSelfAttention(config, is_cross_attention, name="self") - # "output" is a protected attribute on TF models - self.self_output = TFBlipTextSelfOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - output_attentions: bool | None = False, - training: bool | None = None, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - training=training, - ) - attention_output = self.self_output(self_outputs[0], hidden_states, training=training) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "self_output", None) is not None: - with tf.name_scope(self.self_output.name): - self.self_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->BlipText -class TFBlipTextIntermediate(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFBlipTextOutput(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFBlipTextLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.attention = TFBlipTextAttention(config, name="attention") - if self.config.is_decoder: - self.crossattention = TFBlipTextAttention( - config, is_cross_attention=self.config.is_decoder, name="crossattention" - ) - self.intermediate = TFBlipTextIntermediate(config, name="intermediate") - self.self_output = TFBlipTextOutput(config, name="output") - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - training=None, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - training=training, - ) - attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - intermediate_output = self.intermediate(attention_output) - layer_output = self.self_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + outputs - - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "self_output", None) is not None: - with tf.name_scope(self.self_output.name): - self.self_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386 -@keras_serializable -class TFBlipTextEncoder(keras.layers.Layer): - config_class = BlipTextConfig - - def __init__(self, config, name=None, **kwargs): - super().__init__(name=name, **kwargs) - self.config = config - self.layer = [TFBlipTextLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - @unpack_inputs - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - training=None, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.is_decoder else None - - next_decoder_cache = () if use_cache else None - - for i in range(self.config.num_hidden_layers): - layer_module = self.layer[i] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - training=training, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->BlipText -class TFBlipTextPooler(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->BlipText -class TFBlipTextPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: BlipTextConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFBlipTextLMPredictionHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.transform = TFBlipTextPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = keras.layers.Dense( - config.vocab_size, - kernel_initializer=get_initializer(config.initializer_range), - name="decoder", - use_bias=False, - ) - self.config = config - - def build(self, input_shape=None): - self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build([None, None, self.config.hidden_size]) - - def call(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) + self.bias - return hidden_states - - -class TFBlipTextOnlyMLMHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.predictions = TFBlipTextLMPredictionHead(config, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548 -class TFBlipTextPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BlipTextConfig - base_model_prefix = "bert" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - -# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571 -class TFBlipTextModel(TFBlipTextPreTrainedModel): - """ - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in [Attention is - all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and `is_decoder` set to `True`; an - `encoder_hidden_states` is then expected as an input to the forward pass. - """ - - def __init__(self, config, add_pooling_layer=True, name=None, **kwargs): - super().__init__(config, name=name, **kwargs) - self.config = config - - self.embeddings = TFBlipTextEmbeddings(config, name="embeddings") - self.encoder = TFBlipTextEncoder(config, name="encoder") - self.pooler = TFBlipTextPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - @tf.function - def get_extended_attention_mask( - self, attention_mask: tf.Tensor, input_shape: tuple[int], is_decoder: bool - ) -> tf.Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (`tf.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (`tuple[int]`): - The shape of the input to the model. - is_decoder (`bool`): - Whether the model is used as a decoder. - - Returns: - `tf.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if not isinstance(attention_mask, tf.Tensor): - attention_mask = tf.convert_to_tensor(attention_mask) # Catches NumPy inputs that haven't been cast yet - if attention_mask.shape.rank == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.shape.rank == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if is_decoder: - batch_size, seq_length = input_shape - - seq_ids = tf.range(seq_length, dtype=attention_mask.dtype) - causal_mask = tf.broadcast_to(seq_ids, (batch_size, seq_length, seq_length)) <= seq_ids[None, :, None] - # in case past_key_values are used we need to add a prefix ones mask to the causal mask - - if shape_list(causal_mask)[1] < shape_list(attention_mask)[1]: - prefix_seq_len = tf.shape(attention_mask)[1] - tf.shape(causal_mask)[1] - causal_mask = tf.concat( - [ - tf.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype), - causal_mask, - ], - axis=-1, - ) - extended_attention_mask = ( - tf.cast(causal_mask[:, None, :, :], attention_mask.dtype) * attention_mask[:, None, None, :] - ) - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, self.dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - encoder_embeds: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - is_decoder: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPoolingAndCrossAttentions: - r""" - encoder_hidden_states (`tf.Tensor`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - batch_size, seq_length = input_shape - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - batch_size, seq_length = input_shape - elif encoder_embeds is not None: - input_shape = shape_list(encoder_embeds)[:-1] - batch_size, seq_length = input_shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = tf.ones((batch_size, seq_length + past_key_values_length)) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: tf.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, is_decoder) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, list): - encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states[0]) - else: - encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states) - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - - if isinstance(encoder_attention_mask, list): - encoder_extended_attention_mask = [invert_attention_mask(mask) for mask in encoder_attention_mask] - elif encoder_attention_mask is None: - encoder_attention_mask = tf.ones(encoder_hidden_shape) - encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if encoder_embeds is None: - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - else: - embedding_output = encoder_embeds - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 -class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - - self.bert = TFBlipTextModel(config, add_pooling_layer=False, name="bert") - self.cls = TFBlipTextOnlyMLMHead(config, name="cls") - self.label_smoothing = config.label_smoothing - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=True, - training=None, - ): - r""" - encoder_hidden_states (`tf.Tensor`, *optional*): Sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is - configured as a decoder. - encoder_attention_mask (`tf.Tensor`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`tf.Tensor`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - training=training, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores[:, :-1, :] - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :] - shifted_prediction_scores = tf.reshape(shifted_prediction_scores, (-1, self.config.vocab_size)) - labels = labels[:, 1:] - labels = tf.reshape(labels, (-1,)) - # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here - # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway) - one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32) - loss_fct = keras.losses.CategoricalCrossentropy( - from_logits=True, label_smoothing=self.label_smoothing, reduction="none" - ) - masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32) - lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores) - lm_loss *= masked_positions - lm_loss = tf.reduce_sum(lm_loss, axis=0) / tf.math.count_nonzero(masked_positions, dtype=tf.float32) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states"), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask"), - "is_decoder": True, - } - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert", None) is not None: - with tf.name_scope(self.bert.name): - self.bert.build(None) - if getattr(self, "cls", None) is not None: - with tf.name_scope(self.cls.name): - self.cls.build(None) - - -__all__ = ["TFBlipTextLMHeadModel", "TFBlipTextModel", "TFBlipTextPreTrainedModel"] diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py deleted file mode 100644 index c7bb1cc9c9a5..000000000000 --- a/src/transformers/models/bloom/modeling_flax_bloom.py +++ /dev/null @@ -1,737 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax BLOOM model.""" - -import math -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask -from flax.linen.activation import tanh -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutput, -) -from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_bloom import BloomConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "bigscience/bloom" -_CONFIG_FOR_DOC = "BloomConfig" - - -BLOOM_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`BloomConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -BLOOM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32): - """ - Flax implementation of the BLOOM Alibi tensor. BLOOM Alibi tensor is not causal as the original paper mentions, it - relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value - `softmax(l+a) = softmax(l)`. Based on - https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - Link to paper: https://huggingface.co/papers/2108.12409 - - Args: - attention_mask (`jnp.ndarray`): - Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`. - num_heads (`int`): - Number of attention heads. - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - The data type (dtype) of the output tensor. - - Returns: Alibi tensor of shape `(batch_size * num_heads, 1, max_seq_len)`. - """ - batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32) - powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32) - slopes = jax.lax.pow(base, powers) - - if closest_power_of_2 != num_heads: - extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32) - slopes = jnp.cat([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0) - - # Note: the Alibi tensor will added to the attention bias that will be applied to the query, key product of attention - # therefore, Alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # so that the query_length dimension will then be broadcast correctly. - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 - arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :] - alibi = slopes[..., None] * arange_tensor - alibi = jnp.expand_dims(alibi, axis=2) - return jnp.asarray(alibi, dtype) - - -class FlaxBloomAttention(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.hidden_size = self.config.hidden_size - self.num_heads = self.config.n_head - self.head_dim = self.hidden_size // self.num_heads - self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 - - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError( - f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.query_key_value = dense(self.hidden_size * 3) - self.dense = dense(self.hidden_size) - self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) - - @nn.compact - # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key - # positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - residual, - alibi, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - batch_size, seq_length = hidden_states.shape[:2] - - # proj q, k, v - fused_qkv = self.query_key_value(hidden_states) - fused_qkv = self._split_heads(fused_qkv) - query, key, value = jnp.split(fused_qkv, 3, axis=-1) - - causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") - - # for fast decoding causal attention mask should be shifted - causal_attention_mask_shift = ( - self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0 - ) - - # fast decoding for generate requires special attention_mask - if self.has_variable("cache", "cached_key"): - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_attention_mask = jax.lax.dynamic_slice( - causal_attention_mask, - (0, 0, causal_attention_mask_shift, 0), - (1, 1, seq_length, max_decoder_length), - ) - - # broadcast causal attention mask & attention mask to fit for merge - causal_attention_mask = jnp.broadcast_to( - causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] - ) - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape) - attention_mask = combine_masks(attention_mask, causal_attention_mask) - - dropout_rng = None - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - mask_value = jnp.finfo(self.dtype).min - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, mask_value).astype(self.dtype), - ) - - attention_bias = attention_bias + alibi - - # Cast in fp32 if the original dtype is different from fp32 - attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype - - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=attention_dtype, - ) - - # Cast back in the original dtype if the native dtype is not fp32 - if self.attention_softmax_in_fp32: - attn_weights = attn_weights.astype(self.dtype) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.dense(attn_output) - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - - attn_output = attn_output + residual - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class BloomGELU(nn.Module): - def setup(self): - self.dtype = jnp.float32 - - def __call__(self, x): - return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - -class FlaxBloomMLP(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - hidden_size = self.config.hidden_size - - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - - self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) - self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) - self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) - self.act = BloomGELU() - - def __call__(self, hidden_states, residual, deterministic: bool = True): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - - intermediate_output = self.dense_4h_to_h(hidden_states) - - intermediate_output = intermediate_output + residual - hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) - - return hidden_states - - -class FlaxBloomBlock(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype) - self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) - - self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm - self.hidden_dropout = self.config.hidden_dropout - - def __call__( - self, - hidden_states, - alibi, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - layernorm_output = self.input_layernorm(hidden_states) - - # layer norm before saving residual if config calls for it - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - # self-attention - attn_outputs = self.self_attention( - layernorm_output, - residual=residual, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - post_layernorm = self.post_attention_layernorm(attention_output) - - # set residual based on config - if self.apply_residual_connection_post_layernorm: - residual = post_layernorm - else: - residual = attention_output - - output = self.mlp(post_layernorm, residual, deterministic=deterministic) - - outputs = (output,) + outputs - - return outputs - - -class FlaxBloomPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BloomConfig - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: BloomConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - past_key_values: Optional[dict] = None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, sequence_length = input_ids.shape - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # If past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxBloomAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxBloomBlockCollection(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = [ - FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype) - for layer_number in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - alibi, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for layer_number in range(self.config.num_hidden_layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = self.layers[layer_number]( - hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxBloomModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -class FlaxBloomModule(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embed_dim = self.config.hidden_size - - # word embeddings (no positional embedding layer) - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.embed_dim, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - - # post-embedding layernorm - self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - # transformer layers - self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype) - - # final layernorm - self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - def __call__( - self, - input_ids=None, - attention_mask=None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - inputs_embeds = self.word_embeddings(input_ids) - # do post-embedding layernorm - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - - # build alibi depending on `attention_mask` - alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) - - outputs = self.h( - hidden_states, - alibi=alibi, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in [outputs[0], outputs[-1]] if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", - BLOOM_START_DOCSTRING, -) -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom -class FlaxBloomModel(FlaxBloomPreTrainedModel): - module_class = FlaxBloomModule - - -append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) - - -class FlaxBloomForCausalLMModule(nn.Module): - config: BloomConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.transformer = FlaxBloomModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - BLOOM_START_DOCSTRING, -) -class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel): - module_class = FlaxBloomForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for - # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask, - # those positions are masked anyway. Thus, we can create a single static attention_mask here, - # which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - return model_kwargs - - -append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC) - - -__all__ = ["FlaxBloomForCausalLM", "FlaxBloomModel", "FlaxBloomPreTrainedModel"] diff --git a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 9b1b15857cea..000000000000 --- a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The T5 authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert T5 checkpoint.""" - -import argparse - -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - config = T5Config.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = T5ForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_tf_weights_in_t5(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/camembert/modeling_tf_camembert.py b/src/transformers/models/camembert/modeling_tf_camembert.py deleted file mode 100644 index 0869902aa962..000000000000 --- a/src/transformers/models/camembert/modeling_tf_camembert.py +++ /dev/null @@ -1,1800 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 CamemBERT model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_camembert import CamembertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "almanach/camembert-base" -_CONFIG_FOR_DOC = "CamembertConfig" - - -CAMEMBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`CamembertConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CAMEMBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings -class TFCamembertEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - - return incremental_indices + self.padding_idx - - def call( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - inputs_embeds=None, - past_key_values_length=0, - training=False, - ): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids( - input_ids=input_ids, past_key_values_length=past_key_values_length - ) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Camembert -class TFCamembertPooler(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Camembert -class TFCamembertSelfAttention(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFCamembertModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Camembert -class TFCamembertSelfOutput(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Camembert -class TFCamembertAttention(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFCamembertSelfAttention(config, name="self") - self.dense_output = TFCamembertSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Camembert -class TFCamembertIntermediate(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Camembert -class TFCamembertOutput(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Camembert -class TFCamembertLayer(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFCamembertAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFCamembertAttention(config, name="crossattention") - self.intermediate = TFCamembertIntermediate(config, name="intermediate") - self.bert_output = TFCamembertOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Camembert -class TFCamembertEncoder(keras.layers.Layer): - def __init__(self, config: CamembertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFCamembertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->Camembert -class TFCamembertMainLayer(keras.layers.Layer): - config_class = CamembertConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.encoder = TFCamembertEncoder(config, name="encoder") - self.pooler = TFCamembertPooler(config, name="pooler") if add_pooling_layer else None - # The embeddings must be the last declaration in order to follow the weights order - self.embeddings = TFCamembertEmbeddings(config, name="embeddings") - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - - -class TFCamembertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = CamembertConfig - base_model_prefix = "roberta" - - -@add_start_docstrings( - "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.", - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertModel(TFCamembertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.roberta = TFCamembertMainLayer(config, name="roberta") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Camembert -class TFCamembertLMHead(keras.layers.Layer): - """Camembert Head for masked language modeling.""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings( - """CamemBERT Model with a `language modeling` head on top.""", - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForMaskedLM(TFCamembertPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFCamembertLMHead(config, self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - expected_output="' Paris'", - expected_loss=0.1, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead -class TFCamembertClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForSequenceClassification(TFCamembertPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") - self.classifier = TFCamembertClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="cardiffnlp/twitter-roberta-base-emotion", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'optimism'", - expected_loss=0.08, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - CamemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForTokenClassification(TFCamembertPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-large-ner-english", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", - expected_loss=0.01, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - CamemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForMultipleChoice(TFCamembertPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFCamembertMainLayer(config, name="roberta") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - outputs = self.roberta( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - CAMEMBERT_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForQuestionAnswering(TFCamembertPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-base-squad2", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="' puppet'", - expected_loss=0.86, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT -class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config: CamembertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFCamembertLMHeadModel` as a standalone, add `is_decoder=True.`") - - self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFCamembertLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -__all__ = [ - "TFCamembertForCausalLM", - "TFCamembertForMaskedLM", - "TFCamembertForMultipleChoice", - "TFCamembertForQuestionAnswering", - "TFCamembertForSequenceClassification", - "TFCamembertForTokenClassification", - "TFCamembertModel", - "TFCamembertPreTrainedModel", -] diff --git a/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 45dcdb290333..000000000000 --- a/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,65 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert CANINE checkpoint.""" - -import argparse - -from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path): - # Initialize PyTorch model - config = CanineConfig() - model = CanineModel(config) - model.eval() - - print(f"Building PyTorch model from configuration: {config}") - - # Load weights from tf checkpoint - load_tf_weights_in_canine(model, config, tf_checkpoint_path) - - # Save pytorch-model (weights and configuration) - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - # Save tokenizer files - tokenizer = CanineTokenizer() - print(f"Save tokenizer files to {pytorch_dump_path}") - tokenizer.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint. Should end with model.ckpt", - ) - parser.add_argument( - "--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to a folder where the PyTorch model will be placed.", - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path) diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py deleted file mode 100644 index 0394974d0647..000000000000 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ /dev/null @@ -1,1306 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The OpenAI Team Authors, The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Union - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, logging -from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig - - -logger = logging.get_logger(__name__) - -CLIP_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -CLIP_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -CLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -CLIP_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@flax.struct.dataclass -class FlaxCLIPTextModelOutput(ModelOutput): - """ - Base class for text model's outputs that also contains a pooling of the last hidden states. - - Args: - text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of - [`FlaxCLIPTextModel`]. - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - text_embeds: jnp.ndarray = None - last_hidden_state: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray, ...]] = None - attentions: Optional[tuple[jnp.ndarray, ...]] = None - - -@flax.struct.dataclass -class FlaxCLIPOutput(ModelOutput): - """ - Args: - logits_per_image:(`jnp.ndarray` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`jnp.ndarray` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of - [`FlaxCLIPTextModel`]. - image_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of - [`FlaxCLIPVisionModel`]. - text_model_output(`FlaxBaseModelOutputWithPooling`): - The output of the [`FlaxCLIPTextModel`]. - vision_model_output(`FlaxBaseModelOutputWithPooling`): - The output of the [`FlaxCLIPVisionModel`]. - """ - - logits_per_image: jnp.ndarray = None - logits_per_text: jnp.ndarray = None - text_embeds: jnp.ndarray = None - image_embeds: jnp.ndarray = None - text_model_output: FlaxBaseModelOutputWithPooling = None - vision_model_output: FlaxBaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class FlaxCLIPVisionEmbeddings(nn.Module): - config: CLIPVisionConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - image_size = self.config.image_size - patch_size = self.config.patch_size - - self.class_embedding = self.param("class_embedding", jax.nn.initializers.normal(stddev=0.02), (embed_dim,)) - - self.patch_embedding = nn.Conv( - embed_dim, - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - padding="VALID", - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(), - ) - - self.num_patches = (image_size // patch_size) ** 2 - num_positions = self.num_patches + 1 - self.position_embedding = nn.Embed(num_positions, embed_dim, embedding_init=jax.nn.initializers.normal()) - self.position_ids = jnp.expand_dims(jnp.arange(0, num_positions, dtype="i4"), axis=0) - - def __call__(self, pixel_values): - patch_embeds = self.patch_embedding(pixel_values) - batch_size, height, width, channels = patch_embeds.shape - patch_embeds = jnp.reshape(patch_embeds, (batch_size, height * width, channels)) - - class_embeds = jnp.expand_dims(self.class_embedding, axis=(0, 1)) - class_embeds = jnp.tile(class_embeds, (batch_size, 1, 1)) - embeddings = jnp.concatenate([class_embeds, patch_embeds], axis=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -class FlaxCLIPTextEmbeddings(nn.Module): - config: CLIPTextConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - - self.token_embedding = nn.Embed(self.config.vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal()) - self.position_embedding = nn.Embed( - self.config.max_position_embeddings, embed_dim, embedding_init=jax.nn.initializers.normal() - ) - self.position_ids = jnp.expand_dims( - jnp.arange(0, self.config.max_position_embeddings, dtype="i4"), axis=(0, 1) - ) - - def __call__(self, input_ids, position_ids): - input_embeds = self.token_embedding(input_ids.astype("i4")) - position_embeds = self.position_embedding(position_ids.astype("i4")) - - embeddings = input_embeds + position_embeds - return embeddings - - -class FlaxCLIPAttention(nn.Module): - config: Union[CLIPTextConfig, CLIPVisionConfig] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embed_dim = self.config.hidden_size - self.num_heads = self.config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = self.config.attention_dropout - - self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) - self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) - self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) - self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) - - self.causal = isinstance(self.config, CLIPTextConfig) - if self.causal: - self.causal_mask = make_causal_mask(jnp.ones((1, self.config.max_position_embeddings), dtype="i4")) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) - - causal_attention_mask = None - if self.causal: - query_length, key_length = query.shape[1], key.shape[1] - causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length] - - if attention_mask is not None and causal_attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4") - elif causal_attention_mask is not None: - attention_mask = causal_attention_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - if attention_mask is not None: - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxCLIPMLP(nn.Module): - config: Union[CLIPTextConfig, CLIPVisionConfig] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.activation_fn = ACT2FN[self.config.hidden_act] - self.fc1 = nn.Dense( - self.config.intermediate_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.01), - ) - self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) - - def __call__(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class FlaxCLIPEncoderLayer(nn.Module): - config: Union[CLIPTextConfig, CLIPVisionConfig] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self_attn = FlaxCLIPAttention(self.config, dtype=self.dtype) - self.layer_norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.mlp = FlaxCLIPMLP(self.config, dtype=self.dtype) - self.layer_norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - ): - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - attn_outputs = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - hidden_states = attn_outputs[0] - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += attn_outputs[1:] - - return outputs - - -class FlaxCLIPLayerCollection(nn.Module): - config: Union[CLIPTextConfig, CLIPVisionConfig] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = [ - FlaxCLIPEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states,) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxCLIPEncoder(nn.Module): - config: Union[CLIPTextConfig, CLIPVisionConfig] - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = FlaxCLIPLayerCollection(self.config, dtype=self.dtype) - - def __call__( - self, - inputs_embeds, - attention_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layers( - hidden_states=inputs_embeds, - attention_mask=attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxCLIPTextTransformer(nn.Module): - config: CLIPTextConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embeddings = FlaxCLIPTextEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype) - self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - # For `pooled_output` computation - self.eos_token_id = self.config.eos_token_id - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.final_layer_norm(last_hidden_state) - - if self.eos_token_id == 2: - # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. - # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added - # ------------------------------------------------------------ - # text_embeds.shape = [batch_size, sequence_length, transformer.width] - # take features from the EOS embedding (eos_token_id is the highest number in each sequence) - pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)] - else: - # (no need to cast from bool to int after comparing to `eos_token_id`) - pooled_output = last_hidden_state[ - jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1) - ] - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return FlaxBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class FlaxCLIPVisionTransformer(nn.Module): - config: CLIPVisionConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype) - self.pre_layrnorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype) - self.post_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__( - self, - pixel_values=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return FlaxBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel): - config_class = CLIPTextConfig - module_class: nn.Module = None - - def __init__( - self, - config: CLIPTextConfig, - input_shape=(1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensor - input_ids = jnp.zeros(input_shape, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - attention_mask = jnp.ones_like(input_ids) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel): - config_class = CLIPVisionConfig - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: CLIPVisionConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, 3) - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensor - pixel_values = jax.random.normal(rng, input_shape) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, pixel_values)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def __call__( - self, - pixel_values, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): - config_class = CLIPConfig - module_class: nn.Module = None - - def __init__( - self, - config: CLIPConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if input_shape is None: - input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensor - input_ids = jnp.zeros(input_shape[0], dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) - attention_mask = jnp.ones_like(input_ids) - - pixel_values = jax.random.normal(rng, input_shape[1]) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def __call__( - self, - input_ids, - pixel_values, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(pixel_values, dtype=jnp.float32), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - def get_text_features( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train=False, - ): - r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - Returns: - text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of [`FlaxCLIPTextModel`]. - - Examples: - - ```python - >>> from transformers import AutoTokenizer, FlaxCLIPModel - - >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np") - >>> text_features = model.get_text_features(**inputs) - ```""" - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _get_features(module, input_ids, attention_mask, position_ids, deterministic): - text_outputs = module.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - ) - pooled_output = text_outputs[1] - text_features = module.text_projection(pooled_output) - return text_features - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - method=_get_features, - rngs=rngs, - ) - - def get_image_features( - self, pixel_values, params: Optional[dict] = None, dropout_rng: jax.random.PRNGKey = None, train=False - ): - r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained - using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - - Returns: - image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`FlaxCLIPVisionModel`] - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, FlaxCLIPModel - - >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="np") - - >>> image_features = model.get_image_features(**inputs) - ```""" - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _get_features(module, pixel_values, deterministic): - vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic) - pooled_output = vision_outputs[1] # pooled_output - image_features = module.visual_projection(pooled_output) - return image_features - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - method=_get_features, - rngs=rngs, - ) - - -class FlaxCLIPTextModule(nn.Module): - config: CLIPTextConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel): - module_class = FlaxCLIPTextModule - - -FLAX_CLIP_TEXT_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxCLIPTextModel - - >>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooler_output = outputs.pooler_output # pooled (EOS token) states - ``` -""" - -overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING) -append_replace_return_docstrings( - FlaxCLIPTextModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPTextConfig -) - - -class FlaxCLIPTextModelWithProjectionModule(nn.Module): - config: CLIPTextConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype) - self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = text_outputs[1] - text_embeds = self.text_projection(pooled_output) - - if not return_dict: - return (text_embeds, text_outputs[0]) + text_outputs[2:] - - return FlaxCLIPTextModelOutput( - text_embeds=text_embeds, - last_hidden_state=text_outputs.last_hidden_state, - hidden_states=text_outputs.hidden_states, - attentions=text_outputs.attentions, - ) - - -class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel): - module_class = FlaxCLIPTextModelWithProjectionModule - - -FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection - - >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np") - - >>> outputs = model(**inputs) - >>> text_embeds = outputs.text_embeds - ``` -""" - -overwrite_call_docstring( - FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING -) -append_replace_return_docstrings( - FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig -) - - -class FlaxCLIPVisionModule(nn.Module): - config: CLIPVisionConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.vision_model = FlaxCLIPVisionTransformer(self.config, dtype=self.dtype) - - def __call__( - self, - pixel_values, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.vision_model( - pixel_values=pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxCLIPVisionModel(FlaxCLIPVisionPreTrainedModel): - module_class = FlaxCLIPVisionModule - - -FLAX_CLIP_VISION_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, FlaxCLIPVisionModel - - >>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="np") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooler_output = outputs.pooler_output # pooled CLS states - ``` -""" - -overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING) -append_replace_return_docstrings( - FlaxCLIPVisionModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPVisionConfig -) - - -class FlaxCLIPModule(nn.Module): - config: CLIPConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - text_config = self.config.text_config - vision_config = self.config.vision_config - - self.projection_dim = self.config.projection_dim - self.text_embed_dim = text_config.hidden_size - self.vision_embed_dim = vision_config.hidden_size - - self.text_model = FlaxCLIPTextTransformer(text_config, dtype=self.dtype) - self.vision_model = FlaxCLIPVisionTransformer(vision_config, dtype=self.dtype) - - self.visual_projection = nn.Dense( - self.projection_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02), - use_bias=False, - ) - self.text_projection = nn.Dense( - self.projection_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02), - use_bias=False, - ) - - self.logit_scale = self.param( - "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, [] - ) - - def __call__( - self, - input_ids=None, - pixel_values=None, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - image_embeds = vision_outputs[1] - image_embeds = self.visual_projection(image_embeds) - - text_embeds = text_outputs[1] - text_embeds = self.text_projection(text_embeds) - - # normalized features - image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True) - text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = jnp.exp(self.logit_scale) - logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale - logits_per_image = logits_per_text.T - - if not return_dict: - return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - - return FlaxCLIPOutput( - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -@add_start_docstrings(CLIP_START_DOCSTRING) -class FlaxCLIPModel(FlaxCLIPPreTrainedModel): - module_class = FlaxCLIPModule - - -FLAX_CLIP_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> import jax - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, FlaxCLIPModel - - >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True - ... ) - - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ``` -""" - -overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig) - - -__all__ = [ - "FlaxCLIPModel", - "FlaxCLIPPreTrainedModel", - "FlaxCLIPTextModel", - "FlaxCLIPTextPreTrainedModel", - "FlaxCLIPTextModelWithProjection", - "FlaxCLIPVisionModel", - "FlaxCLIPVisionPreTrainedModel", -] diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py deleted file mode 100644 index ab2e38827998..000000000000 --- a/src/transformers/models/clip/modeling_tf_clip.py +++ /dev/null @@ -1,1460 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 CLIP model.""" - -from __future__ import annotations - -import math -from dataclasses import dataclass -from typing import Any - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling - -# Public API -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# contrastive loss function, adapted from -# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html -def contrastive_loss(logits: tf.Tensor) -> tf.Tensor: - return tf.math.reduce_mean( - keras.metrics.sparse_categorical_crossentropy( - y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True - ) - ) - - -def clip_loss(similarity: tf.Tensor) -> tf.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(tf.transpose(similarity)) - return (caption_loss + image_loss) / 2.0 - - -@dataclass -class TFCLIPOutput(ModelOutput): - """ - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image:(`tf.Tensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`tf.Tensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds(`tf.Tensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`TFCLIPTextModel`]. - image_embeds(`tf.Tensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of - [`TFCLIPVisionModel`]. - text_model_output([`~modeling_tf_utils.TFBaseModelOutputWithPooling`]): - The output of the [`TFCLIPTextModel`]. - vision_model_output([`~modeling_tf_utils.TFBaseModelOutputWithPooling`]): - The output of the [`TFCLIPVisionModel`]. - """ - - loss: tf.Tensor | None = None - logits_per_image: tf.Tensor | None = None - logits_per_text: tf.Tensor | None = None - text_embeds: tf.Tensor | None = None - image_embeds: tf.Tensor | None = None - text_model_output: TFBaseModelOutputWithPooling = None - vision_model_output: TFBaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class TFCLIPVisionEmbeddings(keras.layers.Layer): - def __init__(self, config: CLIPVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 - - self.config = config - - self.patch_embedding = keras.layers.Conv2D( - filters=self.embed_dim, - kernel_size=self.patch_size, - strides=self.patch_size, - padding="valid", - data_format="channels_last", - use_bias=False, - kernel_initializer=get_initializer(self.config.initializer_range * self.config.initializer_factor), - name="patch_embedding", - ) - - def build(self, input_shape: tf.TensorShape = None): - factor = self.config.initializer_factor - - self.class_embedding = self.add_weight( - shape=(self.embed_dim,), - initializer=get_initializer(self.embed_dim**-0.5 * factor), - trainable=True, - name="class_embedding", - ) - - with tf.name_scope("position_embedding"): - self.position_embedding = self.add_weight( - shape=(self.num_positions, self.embed_dim), - initializer=get_initializer(self.config.initializer_range * factor), - trainable=True, - name="embeddings", - ) - - if self.built: - return - self.built = True - if getattr(self, "patch_embedding", None) is not None: - with tf.name_scope(self.patch_embedding.name): - self.patch_embedding.build([None, None, None, self.config.num_channels]) - - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: - """`pixel_values` is expected to be of NCHW format.""" - - batch_size, num_channels, height, width = shape_list(pixel_values) - - # When running on CPU, `tf.nn.conv2d` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - patch_embeds = self.patch_embedding(pixel_values) - - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - patch_embeds = tf.reshape(tensor=patch_embeds, shape=(batch_size, self.num_patches, -1)) - - # add the [CLS] token to the embedded patch tokens - class_embeds = tf.broadcast_to(self.class_embedding, shape=(batch_size, 1, self.embed_dim)) - embeddings = tf.concat((class_embeds, patch_embeds), axis=1) - - embeddings = embeddings + self.position_embedding - - return embeddings - - -class TFCLIPTextEmbeddings(keras.layers.Layer): - def __init__(self, config: CLIPTextConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - - self.config = config - - def build(self, input_shape: tf.TensorShape = None): - with tf.name_scope("token_embedding"): - self.weight = self.add_weight( - shape=(self.config.vocab_size, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="weight", - ) - - with tf.name_scope("position_embedding"): - self.position_embedding = self.add_weight( - shape=(self.config.max_position_embeddings, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="embeddings", - ) - - super().build(input_shape) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - position_embeds = tf.gather(params=self.position_embedding, indices=position_ids) - position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) - final_embeddings = inputs_embeds + position_embeds - - return final_embeddings - - -class TFCLIPAttention(keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: CLIPConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = self.embed_dim // self.num_attention_heads - if self.attention_head_size * self.num_attention_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_attention_heads})." - ) - - factor = config.initializer_factor - in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor - out_proj_std = (self.embed_dim**-0.5) * factor - - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.q_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj" - ) - self.k_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj" - ) - self.v_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj" - ) - - self.dropout = keras.layers.Dropout(rate=config.attention_dropout) - - self.out_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj" - ) - - # copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - """Input shape: Batch x Time x Channel""" - - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.q_proj(inputs=hidden_states) - mixed_key_layer = self.k_proj(inputs=hidden_states) - mixed_value_layer = self.v_proj(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - # Apply the causal attention mask (precomputed for all layers in TFCLIPModel call() function) - attention_scores = tf.add(attention_scores, causal_attention_mask) - - if attention_mask is not None: - # Apply the attention mask (precomputed for all layers in TFCLIPModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - _attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=_attention_probs, training=training) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, embed_dim) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.embed_dim)) - - attention_output = self.out_proj(attention_output, training=training) - # In TFBert, attention weights are returned after dropout. - # However, in CLIP, they are returned before dropout. - outputs = (attention_output, _attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFCLIPMLP(keras.layers.Layer): - def __init__(self, config: CLIPConfig, **kwargs): - super().__init__(**kwargs) - - self.activation_fn = get_tf_activation(config.hidden_act) - - factor = config.initializer_factor - in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor - fc_std = (2 * config.hidden_size) ** -0.5 * factor - - self.fc1 = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(fc_std), name="fc1" - ) - self.fc2 = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2" - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.fc1(inputs=hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(inputs=hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.config.hidden_size]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.intermediate_size]) - - -class TFCLIPEncoderLayer(keras.layers.Layer): - def __init__(self, config: CLIPConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.self_attn = TFCLIPAttention(config, name="self_attn") - self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") - self.mlp = TFCLIPMLP(config, name="mlp") - self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - causal_attention_mask (`tf.Tensor`): causal attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`): - Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned - tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(inputs=hidden_states) - attention_outputs = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = attention_outputs[0] - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(inputs=hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, self.embed_dim]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, self.embed_dim]) - - -class TFCLIPEncoder(keras.layers.Layer): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`TFCLIPEncoderLayer`]. - - Args: - config: CLIPConfig - """ - - def __init__(self, config: CLIPConfig, **kwargs): - super().__init__(**kwargs) - - self.layers = [TFCLIPEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFCLIPTextTransformer(keras.layers.Layer): - def __init__(self, config: CLIPTextConfig, **kwargs): - super().__init__(**kwargs) - - self.embeddings = TFCLIPTextEmbeddings(config, name="embeddings") - self.encoder = TFCLIPEncoder(config, name="encoder") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - - # For `pooled_output` computation - self.eos_token_id = config.eos_token_id - self.embed_dim = config.hidden_size - - def call( - self, - input_ids: TFModelInputType, - attention_mask: tf.Tensor, - position_ids: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - input_shape = shape_list(input_ids) - - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - batch_size, seq_length = input_shape - # CLIP's text model uses causal mask, prepare it here. - # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype) - - # check attention mask and invert - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.final_layer_norm(inputs=sequence_output) - - if self.eos_token_id == 2: - # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. - # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added - # ------------------------------------------------------------ - # text_embeds.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - pooled_output = tf.gather_nd( - params=sequence_output, - indices=tf.stack( - values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1 - ), - ) - else: - # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible) - pooled_output = tf.gather_nd( - params=sequence_output, - indices=tf.stack( - values=( - tf.range(input_shape[0], dtype=tf.int64), - tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1), - ), - axis=1, - ), - ) - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32): - # It is possible with an unspecified sequence length for seq_length to be - # a runtime value, which is unsupported by tf.constant. Per the TensorFlow - # docs, tf.fill can handle runtime dynamic shapes: - # https://www.tensorflow.org/api_docs/python/tf/fill - diag = tf.cast(tf.fill((seq_length,), 0.0), dtype) - - # set an additive 2D attention mask with all places being masked - to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype) - - # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked) - # TIP: think the 2D matrix as the space of (query_seq, key_seq) - to_mask = tf.linalg.band_part(to_mask, 0, -1) - # to_mask = tf.linalg.band_part(to_mask, -1, 0) - to_mask = tf.linalg.set_diag(to_mask, diagonal=diag) - - return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -@keras_serializable -class TFCLIPTextMainLayer(keras.layers.Layer): - config_class = CLIPTextConfig - - def __init__(self, config: CLIPTextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.text_model = TFCLIPTextTransformer(config, name="text_model") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.text_model.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.text_model.embeddings.weight = value - self.text_model.embeddings.vocab_size = shape_list(value)[0] - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - text_model_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return text_model_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "text_model", None) is not None: - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - - -class TFCLIPVisionTransformer(keras.layers.Layer): - def __init__(self, config: CLIPVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.embeddings = TFCLIPVisionEmbeddings(config, name="embeddings") - self.pre_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm") - self.encoder = TFCLIPEncoder(config, name="encoder") - self.post_layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") - self.embed_dim = config.hidden_size - - def call( - self, - pixel_values: TFModelInputType, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - embedding_output = self.embeddings(pixel_values=pixel_values) - embedding_output = self.pre_layernorm(inputs=embedding_output) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=None, - causal_attention_mask=None, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = sequence_output[:, 0, :] - pooled_output = self.post_layernorm(inputs=pooled_output) - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "pre_layernorm", None) is not None: - with tf.name_scope(self.pre_layernorm.name): - self.pre_layernorm.build([None, None, self.embed_dim]) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "post_layernorm", None) is not None: - with tf.name_scope(self.post_layernorm.name): - self.post_layernorm.build([None, self.embed_dim]) - - -@keras_serializable -class TFCLIPVisionMainLayer(keras.layers.Layer): - config_class = CLIPVisionConfig - - def __init__(self, config: CLIPVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.vision_model = TFCLIPVisionTransformer(config, name="vision_model") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.vision_model.embeddings - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - vision_model_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return vision_model_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - - -@keras_serializable -class TFCLIPMainLayer(keras.layers.Layer): - config_class = CLIPConfig - - def __init__(self, config: CLIPConfig, **kwargs): - super().__init__(**kwargs) - - if not isinstance(config.text_config, CLIPTextConfig): - raise TypeError( - "config.text_config is expected to be of type CLIPTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, CLIPVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type CLIPVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - self.config = config - - text_config = config.text_config - vision_config = config.vision_config - - self.projection_dim = config.projection_dim - - self.text_model = TFCLIPTextTransformer(text_config, name="text_model") - self.vision_model = TFCLIPVisionTransformer(vision_config, name="vision_model") - - self.visual_projection = keras.layers.Dense( - units=self.projection_dim, - kernel_initializer=get_initializer(vision_config.hidden_size**-0.5 * self.config.initializer_factor), - use_bias=False, - name="visual_projection", - ) - - self.text_projection = keras.layers.Dense( - units=self.projection_dim, - kernel_initializer=get_initializer(text_config.hidden_size**-0.5 * self.config.initializer_factor), - use_bias=False, - name="text_projection", - ) - self.text_embed_dim = text_config.hidden_size - self.vision_embed_dim = vision_config.hidden_size - - def build(self, input_shape: tf.TensorShape = None): - self.logit_scale = self.add_weight( - shape=(1,), - initializer=keras.initializers.Constant(self.config.logit_scale_init_value), - trainable=True, - name="logit_scale", - ) - - if self.built: - return - self.built = True - if getattr(self, "text_model", None) is not None: - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "visual_projection", None) is not None: - with tf.name_scope(self.visual_projection.name): - self.visual_projection.build([None, None, self.vision_embed_dim]) - if getattr(self, "text_projection", None) is not None: - with tf.name_scope(self.text_projection.name): - self.text_projection.build([None, None, self.text_embed_dim]) - - @unpack_inputs - def get_text_features( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - if input_ids is None: - raise ValueError("You have to specify either input_ids") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = text_outputs[1] - text_features = self.text_projection(inputs=pooled_output) - - return text_features - - @unpack_inputs - def get_image_features( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = vision_outputs[1] # pooled_output - image_features = self.visual_projection(inputs=pooled_output) - - return image_features - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - pixel_values: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFCLIPOutput | tuple[tf.Tensor]: - if input_ids is None: - raise ValueError("You have to specify either input_ids") - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[1] - image_embeds = self.visual_projection(inputs=image_embeds) - - text_embeds = text_outputs[1] - text_embeds = self.text_projection(inputs=text_embeds) - - # normalized features - image_embeds = image_embeds / tf.norm(tensor=image_embeds, ord="euclidean", axis=-1, keepdims=True) - text_embeds = text_embeds / tf.norm(tensor=text_embeds, ord="euclidean", axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = tf.math.exp(self.logit_scale) - logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale - logits_per_image = tf.transpose(logits_per_text) - - loss = None - if return_loss: - loss = clip_loss(logits_per_text) - loss = tf.reshape(loss, (1,)) - - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return (loss,) + output if loss is not None else output - - return TFCLIPOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -class TFCLIPPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = CLIPConfig - base_model_prefix = "clip" - _keys_to_ignore_on_load_missing = [r"position_ids"] - _keys_to_ignore_on_load_unexpected = [r"position_ids"] - - -CLIP_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CLIP_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -CLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`CLIPImageProcessor.__call__`] for details. output_attentions (`bool`, *optional*): Whether or not to - return the attentions tensors of all attention layers. See `attentions` under returned tensors for more - detail. This argument can be used only in eager mode, in graph mode the value in the config will be used - instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -CLIP_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`CLIPImageProcessor.__call__`] for details. - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -class TFCLIPTextModel(TFCLIPPreTrainedModel): - config_class = CLIPTextConfig - - def __init__(self, config: CLIPTextConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.clip = TFCLIPTextMainLayer(config, name="clip") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=CLIPTextConfig) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFCLIPTextModel - - >>> model = TFCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - - outputs = self.clip( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "clip", None) is not None: - with tf.name_scope(self.clip.name): - self.clip.build(None) - - -class TFCLIPVisionModel(TFCLIPPreTrainedModel): - config_class = CLIPVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: CLIPVisionConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.clip = TFCLIPVisionMainLayer(config, name="clip") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=CLIPVisionConfig) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFCLIPVisionModel - - >>> model = TFCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled CLS states - ```""" - - outputs = self.clip( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "clip", None) is not None: - with tf.name_scope(self.clip.name): - self.clip.build(None) - - -@add_start_docstrings(CLIP_START_DOCSTRING) -class TFCLIPModel(TFCLIPPreTrainedModel): - config_class = CLIPConfig - - def __init__(self, config: CLIPConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.clip = TFCLIPMainLayer(config, name="clip") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def get_text_features( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - r""" - Returns: - text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of [`TFCLIPTextModel`]. - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFCLIPModel - - >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") - >>> text_features = model.get_text_features(**inputs) - ```""" - - text_features = self.clip.get_text_features( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return text_features - - @unpack_inputs - @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) - def get_image_features( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - r""" - Returns: - image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying - the projection layer to the pooled output of [`TFCLIPVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFCLIPModel - - >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> image_features = model.get_image_features(**inputs) - ```""" - - image_features = self.clip.get_image_features( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return image_features - - @unpack_inputs - @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFCLIPOutput, config_class=CLIPConfig) - def call( - self, - input_ids: TFModelInputType | None = None, - pixel_values: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFCLIPOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFCLIPModel - - >>> model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True - ... ) - - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = tf.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ```""" - - outputs = self.clip( - input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids, - return_loss=return_loss, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return outputs - - def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput: - # TODO: As is this currently fails with saved_model=True, because - # TensorFlow cannot trace through nested dataclasses. Reference: - # https://github.com/huggingface/transformers/pull/16886 - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "clip", None) is not None: - with tf.name_scope(self.clip.name): - self.clip.build(None) - - -__all__ = ["TFCLIPModel", "TFCLIPPreTrainedModel", "TFCLIPTextModel", "TFCLIPVisionModel"] diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py deleted file mode 100644 index 3d4ff779874b..000000000000 --- a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert ConvBERT checkpoint.""" - -import argparse - -from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): - conf = ConvBertConfig.from_json_file(convbert_config_file) - model = ConvBertModel(conf) - - model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) - model.save_pretrained(pytorch_dump_path) - - tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) - tf_model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--convbert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained ConvBERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py deleted file mode 100644 index 47c720f5c12c..000000000000 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ /dev/null @@ -1,1474 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ConvBERT model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_convbert import ConvBertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "YituTech/conv-bert-base" -_CONFIG_FOR_DOC = "ConvBertConfig" - - -# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->ConvBert -class TFConvBertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: ConvBertConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - past_key_values_length=0, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFConvBertSelfAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - new_num_attention_heads = int(config.num_attention_heads / config.head_ratio) - if new_num_attention_heads < 1: - self.head_ratio = config.num_attention_heads - num_attention_heads = 1 - else: - num_attention_heads = new_num_attention_heads - self.head_ratio = config.head_ratio - - self.num_attention_heads = num_attention_heads - self.conv_kernel_size = config.conv_kernel_size - - if config.hidden_size % self.num_attention_heads != 0: - raise ValueError("hidden_size should be divisible by num_attention_heads") - - self.attention_head_size = config.hidden_size // config.num_attention_heads - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - - self.key_conv_attn_layer = keras.layers.SeparableConv1D( - self.all_head_size, - self.conv_kernel_size, - padding="same", - activation=None, - depthwise_initializer=get_initializer(1 / self.conv_kernel_size), - pointwise_initializer=get_initializer(config.initializer_range), - name="key_conv_attn_layer", - ) - - self.conv_kernel_layer = keras.layers.Dense( - self.num_attention_heads * self.conv_kernel_size, - activation=None, - name="conv_kernel_layer", - kernel_initializer=get_initializer(config.initializer_range), - ) - - self.conv_out_layer = keras.layers.Dense( - self.all_head_size, - activation=None, - name="conv_out_layer", - kernel_initializer=get_initializer(config.initializer_range), - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, x, batch_size): - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False): - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - conv_attn_layer = tf.multiply(mixed_key_conv_attn_layer, mixed_query_layer) - - conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) - conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1]) - conv_kernel_layer = stable_softmax(conv_kernel_layer, axis=1) - - paddings = tf.constant( - [ - [ - 0, - 0, - ], - [int((self.conv_kernel_size - 1) / 2), int((self.conv_kernel_size - 1) / 2)], - [0, 0], - ] - ) - - conv_out_layer = self.conv_out_layer(hidden_states) - conv_out_layer = tf.reshape(conv_out_layer, [batch_size, -1, self.all_head_size]) - conv_out_layer = tf.pad(conv_out_layer, paddings, "CONSTANT") - - unfold_conv_out_layer = tf.stack( - [ - tf.slice(conv_out_layer, [0, i, 0], [batch_size, shape_list(mixed_query_layer)[1], self.all_head_size]) - for i in range(self.conv_kernel_size) - ], - axis=-1, - ) - - conv_out_layer = tf.reshape(unfold_conv_out_layer, [-1, self.attention_head_size, self.conv_kernel_size]) - - conv_out_layer = tf.matmul(conv_out_layer, conv_kernel_layer) - conv_out_layer = tf.reshape(conv_out_layer, [-1, self.all_head_size]) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul( - query_layer, key_layer, transpose_b=True - ) # (batch size, num_heads, seq_len_q, seq_len_k) - dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores - attention_scores = attention_scores / tf.math.sqrt(dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - value_layer = tf.reshape( - mixed_value_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size] - ) - value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - - conv_out = tf.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size]) - context_layer = tf.concat([context_layer, conv_out], 2) - context_layer = tf.reshape( - context_layer, (batch_size, -1, self.head_ratio * self.all_head_size) - ) # (batch_size, seq_len_q, all_head_size) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "key_conv_attn_layer", None) is not None: - with tf.name_scope(self.key_conv_attn_layer.name): - self.key_conv_attn_layer.build([None, None, self.config.hidden_size]) - if getattr(self, "conv_kernel_layer", None) is not None: - with tf.name_scope(self.conv_kernel_layer.name): - self.conv_kernel_layer.build([None, None, self.all_head_size]) - if getattr(self, "conv_out_layer", None) is not None: - with tf.name_scope(self.conv_out_layer.name): - self.conv_out_layer.build([None, None, self.config.hidden_size]) - - -class TFConvBertSelfOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFConvBertAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFConvBertSelfAttention(config, name="self") - self.dense_output = TFConvBertSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False): - self_outputs = self.self_attention( - input_tensor, attention_mask, head_mask, output_attentions, training=training - ) - attention_output = self.dense_output(self_outputs[0], input_tensor, training=training) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class GroupedLinearLayer(keras.layers.Layer): - def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kwargs): - super().__init__(**kwargs) - self.input_size = input_size - self.output_size = output_size - self.num_groups = num_groups - self.kernel_initializer = kernel_initializer - self.group_in_dim = self.input_size // self.num_groups - self.group_out_dim = self.output_size // self.num_groups - - def build(self, input_shape=None): - self.kernel = self.add_weight( - "kernel", - shape=[self.group_out_dim, self.group_in_dim, self.num_groups], - initializer=self.kernel_initializer, - trainable=True, - ) - - self.bias = self.add_weight( - "bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True - ) - super().build(input_shape) - - def call(self, hidden_states): - batch_size = shape_list(hidden_states)[0] - x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2]) - x = tf.matmul(x, tf.transpose(self.kernel, [2, 1, 0])) - x = tf.transpose(x, [1, 0, 2]) - x = tf.reshape(x, [batch_size, -1, self.output_size]) - x = tf.nn.bias_add(value=x, bias=self.bias) - return x - - -class TFConvBertIntermediate(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - if config.num_groups == 1: - self.dense = keras.layers.Dense( - config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - else: - self.dense = GroupedLinearLayer( - config.hidden_size, - config.intermediate_size, - num_groups=config.num_groups, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFConvBertOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - if config.num_groups == 1: - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - else: - self.dense = GroupedLinearLayer( - config.intermediate_size, - config.hidden_size, - num_groups=config.num_groups, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -class TFConvBertLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.attention = TFConvBertAttention(config, name="attention") - self.intermediate = TFConvBertIntermediate(config, name="intermediate") - self.bert_output = TFConvBertOutput(config, name="output") - - def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False): - attention_outputs = self.attention( - hidden_states, attention_mask, head_mask, output_attentions, training=training - ) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(attention_output) - layer_output = self.bert_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - - -class TFConvBertEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFConvBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states, - attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=False, - ): - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], output_attentions, training=training - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFConvBertPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -@keras_serializable -class TFConvBertMainLayer(keras.layers.Layer): - config_class = ConvBertConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.embeddings = TFConvBertEmbeddings(config, name="embeddings") - - if config.embedding_size != config.hidden_size: - self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project") - - self.encoder = TFConvBertEncoder(config, name="encoder") - self.config = config - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = value.shape[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - def get_extended_attention_mask(self, attention_mask, input_shape, dtype): - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - return extended_attention_mask - - def get_head_mask(self, head_mask): - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - return head_mask - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype) - head_mask = self.get_head_mask(head_mask) - - if hasattr(self, "embeddings_project"): - hidden_states = self.embeddings_project(hidden_states, training=training) - - hidden_states = self.encoder( - hidden_states, - extended_attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=training, - ) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "embeddings_project", None) is not None: - with tf.name_scope(self.embeddings_project.name): - self.embeddings_project.build([None, None, self.config.embedding_size]) - - -class TFConvBertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ConvBertConfig - base_model_prefix = "convbert" - - -CONVBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`ConvBertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CONVBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare ConvBERT Model transformer outputting raw hidden-states without any specific head on top.", - CONVBERT_START_DOCSTRING, -) -class TFConvBertModel(TFConvBertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.convbert = TFConvBertMainLayer(config, name="convbert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.array | tf.Tensor | None = None, - token_type_ids: np.array | tf.Tensor | None = None, - position_ids: np.array | tf.Tensor | None = None, - head_mask: np.array | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.convbert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - - -class TFConvBertMaskedLMHead(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -class TFConvBertGeneratorPredictions(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dense = keras.layers.Dense(config.embedding_size, name="dense") - self.config = config - - def call(self, generator_hidden_states, training=False): - hidden_states = self.dense(generator_hidden_states) - hidden_states = get_tf_activation("gelu")(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING) -class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, **kwargs) - - self.config = config - self.convbert = TFConvBertMainLayer(config, name="convbert") - self.generator_predictions = TFConvBertGeneratorPredictions(config, name="generator_predictions") - - if isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - - self.generator_lm_head = TFConvBertMaskedLMHead(config, self.convbert.embeddings, name="generator_lm_head") - - def get_lm_head(self): - return self.generator_lm_head - - def get_prefix_bias_name(self): - return self.name + "/" + self.generator_lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFMaskedLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - generator_hidden_states = self.convbert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - generator_sequence_output = generator_hidden_states[0] - prediction_scores = self.generator_predictions(generator_sequence_output, training=training) - prediction_scores = self.generator_lm_head(prediction_scores, training=training) - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + generator_hidden_states[1:] - - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=generator_hidden_states.hidden_states, - attentions=generator_hidden_states.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - if getattr(self, "generator_predictions", None) is not None: - with tf.name_scope(self.generator_predictions.name): - self.generator_predictions.build(None) - if getattr(self, "generator_lm_head", None) is not None: - with tf.name_scope(self.generator_lm_head.name): - self.generator_lm_head.build(None) - - -class TFConvBertClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - - self.config = config - - def call(self, hidden_states, **kwargs): - x = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x) - x = self.dense(x) - x = get_tf_activation(self.config.hidden_act)(x) - x = self.dropout(x) - x = self.out_proj(x) - - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - ConvBERT Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks. - """, - CONVBERT_START_DOCSTRING, -) -class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.convbert = TFConvBertMainLayer(config, name="convbert") - self.classifier = TFConvBertClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFSequenceClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.convbert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.classifier(outputs[0], training=training) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - ConvBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - CONVBERT_START_DOCSTRING, -) -class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.convbert = TFConvBertMainLayer(config, name="convbert") - self.sequence_summary = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="sequence_summary" - ) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFMultipleChoiceModelOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.convbert( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.sequence_summary(outputs[0], training=training) - logits = self.classifier(logits) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - ConvBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - CONVBERT_START_DOCSTRING, -) -class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.convbert = TFConvBertMainLayer(config, name="convbert") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFTokenClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.convbert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - ConvBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - CONVBERT_START_DOCSTRING, -) -class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.convbert = TFConvBertMainLayer(config, name="convbert") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: tf.Tensor | None = None, - end_positions: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFQuestionAnsweringModelOutput: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.convbert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convbert", None) is not None: - with tf.name_scope(self.convbert.name): - self.convbert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFConvBertForMaskedLM", - "TFConvBertForMultipleChoice", - "TFConvBertForQuestionAnswering", - "TFConvBertForSequenceClassification", - "TFConvBertForTokenClassification", - "TFConvBertLayer", - "TFConvBertModel", - "TFConvBertPreTrainedModel", -] diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py deleted file mode 100644 index 7306877466d9..000000000000 --- a/src/transformers/models/convnext/modeling_tf_convnext.py +++ /dev/null @@ -1,667 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ConvNext model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_convnext import ConvNextConfig - - -logger = logging.get_logger(__name__) - - -_CONFIG_FOR_DOC = "ConvNextConfig" -_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224" - - -class TFConvNextDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_path: float, **kwargs): - super().__init__(**kwargs) - self.drop_path = drop_path - - def call(self, x: tf.Tensor, training=None): - if training: - keep_prob = 1 - self.drop_path - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x - - -class TFConvNextEmbeddings(keras.layers.Layer): - """This class is comparable to (and inspired by) the SwinEmbeddings class - found in src/transformers/models/swin/modeling_swin.py. - """ - - def __init__(self, config: ConvNextConfig, **kwargs): - super().__init__(**kwargs) - self.patch_embeddings = keras.layers.Conv2D( - filters=config.hidden_sizes[0], - kernel_size=config.patch_size, - strides=config.patch_size, - name="patch_embeddings", - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - ) - self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm") - self.num_channels = config.num_channels - self.config = config - - def call(self, pixel_values): - if isinstance(pixel_values, dict): - pixel_values = pixel_values["pixel_values"] - - tf.debugging.assert_equal( - shape_list(pixel_values)[1], - self.num_channels, - message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - embeddings = self.patch_embeddings(pixel_values) - embeddings = self.layernorm(embeddings) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build([None, None, None, self.config.num_channels]) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, None, self.config.hidden_sizes[0]]) - - -class TFConvNextLayer(keras.layers.Layer): - """This corresponds to the `Block` class in the original implementation. - - There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C, - H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back - - The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow - NHWC ordering, we can just apply the operations straight-away without the permutation. - - Args: - config ([`ConvNextConfig`]): Model configuration class. - dim (`int`): Number of input channels. - drop_path (`float`): Stochastic depth rate. Default: 0.0. - """ - - def __init__(self, config, dim, drop_path=0.0, **kwargs): - super().__init__(**kwargs) - self.dim = dim - self.config = config - self.dwconv = keras.layers.Conv2D( - filters=dim, - kernel_size=7, - padding="same", - groups=dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="dwconv", - ) # depthwise conv - self.layernorm = keras.layers.LayerNormalization( - epsilon=1e-6, - name="layernorm", - ) - self.pwconv1 = keras.layers.Dense( - units=4 * dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="pwconv1", - ) # pointwise/1x1 convs, implemented with linear layers - self.act = get_tf_activation(config.hidden_act) - self.pwconv2 = keras.layers.Dense( - units=dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="pwconv2", - ) - # Using `layers.Activation` instead of `tf.identity` to better control `training` - # behaviour. - self.drop_path = ( - TFConvNextDropPath(drop_path, name="drop_path") - if drop_path > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - - def build(self, input_shape: tf.TensorShape = None): - # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa) - self.layer_scale_parameter = ( - self.add_weight( - shape=(self.dim,), - initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), - trainable=True, - name="layer_scale_parameter", - ) - if self.config.layer_scale_init_value > 0 - else None - ) - - if self.built: - return - self.built = True - if getattr(self, "dwconv", None) is not None: - with tf.name_scope(self.dwconv.name): - self.dwconv.build([None, None, None, self.dim]) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, None, self.dim]) - if getattr(self, "pwconv1", None) is not None: - with tf.name_scope(self.pwconv1.name): - self.pwconv1.build([None, None, self.dim]) - if getattr(self, "pwconv2", None) is not None: - with tf.name_scope(self.pwconv2.name): - self.pwconv2.build([None, None, 4 * self.dim]) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - - def call(self, hidden_states, training=False): - input = hidden_states - x = self.dwconv(hidden_states) - x = self.layernorm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - - if self.layer_scale_parameter is not None: - x = self.layer_scale_parameter * x - - x = input + self.drop_path(x, training=training) - return x - - -class TFConvNextStage(keras.layers.Layer): - """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks. - - Args: - config (`ConvNextV2Config`): - Model configuration class. - in_channels (`int`): - Number of input channels. - out_channels (`int`): - Number of output channels. - depth (`int`): - Number of residual blocks. - drop_path_rates(`list[float]`): - Stochastic depth rates for each layer. - """ - - def __init__( - self, - config: ConvNextConfig, - in_channels: int, - out_channels: int, - kernel_size: int = 2, - stride: int = 2, - depth: int = 2, - drop_path_rates: list[float] | None = None, - **kwargs, - ): - super().__init__(**kwargs) - if in_channels != out_channels or stride > 1: - self.downsampling_layer = [ - keras.layers.LayerNormalization( - epsilon=1e-6, - name="downsampling_layer.0", - ), - # Inputs to this layer will follow NHWC format since we - # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings` - # layer. All the outputs throughout the model will be in NHWC - # from this point on until the output where we again change to - # NCHW. - keras.layers.Conv2D( - filters=out_channels, - kernel_size=kernel_size, - strides=stride, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="downsampling_layer.1", - ), - ] - else: - self.downsampling_layer = [tf.identity] - - drop_path_rates = drop_path_rates or [0.0] * depth - self.layers = [ - TFConvNextLayer( - config, - dim=out_channels, - drop_path=drop_path_rates[j], - name=f"layers.{j}", - ) - for j in range(depth) - ] - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = stride - - def call(self, hidden_states): - for layer in self.downsampling_layer: - hidden_states = layer(hidden_states) - for layer in self.layers: - hidden_states = layer(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - if self.in_channels != self.out_channels or self.stride > 1: - with tf.name_scope(self.downsampling_layer[0].name): - self.downsampling_layer[0].build([None, None, None, self.in_channels]) - with tf.name_scope(self.downsampling_layer[1].name): - self.downsampling_layer[1].build([None, None, None, self.in_channels]) - - -class TFConvNextEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.stages = [] - drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths)) - drop_path_rates = tf.split(drop_path_rates, config.depths) - drop_path_rates = [x.numpy().tolist() for x in drop_path_rates] - prev_chs = config.hidden_sizes[0] - for i in range(config.num_stages): - out_chs = config.hidden_sizes[i] - stage = TFConvNextStage( - config, - in_channels=prev_chs, - out_channels=out_chs, - stride=2 if i > 0 else 1, - depth=config.depths[i], - drop_path_rates=drop_path_rates[i], - name=f"stages.{i}", - ) - self.stages.append(stage) - prev_chs = out_chs - - def call(self, hidden_states, output_hidden_states=False, return_dict=True): - all_hidden_states = () if output_hidden_states else None - - for i, layer_module in enumerate(self.stages): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = layer_module(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - - return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) - - def build(self, input_shape=None): - for stage in self.stages: - with tf.name_scope(stage.name): - stage.build(None) - - -@keras_serializable -class TFConvNextMainLayer(keras.layers.Layer): - config_class = ConvNextConfig - - def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embeddings = TFConvNextEmbeddings(config, name="embeddings") - self.encoder = TFConvNextEncoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - # We are setting the `data_format` like so because from here on we will revert to the - # NCHW output format - self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - embedding_output = self.embeddings(pixel_values, training=training) - - encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = encoder_outputs[0] - # Change to NCHW output format have uniformity in the modules - last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) - pooled_output = self.layernorm(self.pooler(last_hidden_state)) - - # Change the other hidden state outputs to NCHW as well - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) - - if not return_dict: - hidden_states = hidden_states if output_hidden_states else () - return (last_hidden_state, pooled_output) + hidden_states - - return TFBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, self.config.hidden_sizes[-1]]) - - -class TFConvNextPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ConvNextConfig - base_model_prefix = "convnext" - main_input_name = "pixel_values" - - -CONVNEXT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CONVNEXT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`ConvNextImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. -""" - - -@add_start_docstrings( - "The bare ConvNext model outputting raw features without any specific head on top.", - CONVNEXT_START_DOCSTRING, -) -class TFConvNextModel(TFConvNextPreTrainedModel): - def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFConvNextModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224") - >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - outputs = self.convnext( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return (outputs[0],) + outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=outputs.last_hidden_state, - pooler_output=outputs.pooler_output, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convnext", None) is not None: - with tf.name_scope(self.convnext.name): - self.convnext.build(None) - - -@add_start_docstrings( - """ - ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - CONVNEXT_START_DOCSTRING, -) -class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: ConvNextConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.convnext = TFConvNextMainLayer(config, name="convnext") - - # Classifier head - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224") - >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] - >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) - ```""" - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - outputs = self.convnext( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convnext", None) is not None: - with tf.name_scope(self.convnext.name): - self.convnext.build(None) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_sizes[-1]]) - - -__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"] diff --git a/src/transformers/models/convnextv2/modeling_tf_convnextv2.py b/src/transformers/models/convnextv2/modeling_tf_convnextv2.py deleted file mode 100644 index d370c3008d47..000000000000 --- a/src/transformers/models/convnextv2/modeling_tf_convnextv2.py +++ /dev/null @@ -1,681 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ConvNextV2 model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithNoAttention, - TFBaseModelOutputWithPooling, - TFBaseModelOutputWithPoolingAndNoAttention, - TFImageClassifierOutputWithNoAttention, -) -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_convnextv2 import ConvNextV2Config - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "ConvNextV2Config" - -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/convnextv2-tiny-1k-224" -_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "facebook/convnextv2-tiny-1k-224" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->ConvNextV2 -class TFConvNextV2DropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_path: float, **kwargs): - super().__init__(**kwargs) - self.drop_path = drop_path - - def call(self, x: tf.Tensor, training=None): - if training: - keep_prob = 1 - self.drop_path - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x - - -class TFConvNextV2GRN(keras.layers.Layer): - """GRN (Global Response Normalization) layer""" - - def __init__(self, config: ConvNextV2Config, dim: int, **kwargs): - super().__init__(**kwargs) - self.dim = dim - - def build(self, input_shape: tf.TensorShape = None): - # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa) - self.weight = self.add_weight( - name="weight", - shape=(1, 1, 1, self.dim), - initializer=keras.initializers.Zeros(), - ) - self.bias = self.add_weight( - name="bias", - shape=(1, 1, 1, self.dim), - initializer=keras.initializers.Zeros(), - ) - return super().build(input_shape) - - def call(self, hidden_states: tf.Tensor): - global_features = tf.norm(hidden_states, ord="euclidean", axis=(1, 2), keepdims=True) - norm_features = global_features / (tf.reduce_mean(global_features, axis=-1, keepdims=True) + 1e-6) - hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states - return hidden_states - - -# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextEmbeddings with ConvNext->ConvNextV2 -class TFConvNextV2Embeddings(keras.layers.Layer): - """This class is comparable to (and inspired by) the SwinEmbeddings class - found in src/transformers/models/swin/modeling_swin.py. - """ - - def __init__(self, config: ConvNextV2Config, **kwargs): - super().__init__(**kwargs) - self.patch_embeddings = keras.layers.Conv2D( - filters=config.hidden_sizes[0], - kernel_size=config.patch_size, - strides=config.patch_size, - name="patch_embeddings", - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - ) - self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm") - self.num_channels = config.num_channels - self.config = config - - def call(self, pixel_values): - if isinstance(pixel_values, dict): - pixel_values = pixel_values["pixel_values"] - - tf.debugging.assert_equal( - shape_list(pixel_values)[1], - self.num_channels, - message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - embeddings = self.patch_embeddings(pixel_values) - embeddings = self.layernorm(embeddings) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build([None, None, None, self.config.num_channels]) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, None, self.config.hidden_sizes[0]]) - - -class TFConvNextV2Layer(keras.layers.Layer): - """This corresponds to the `Block` class in the original implementation. - - There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C, - H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back - - The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow - NHWC ordering, we can just apply the operations straight-away without the permutation. - - Args: - config (`ConvNextV2Config`): - Model configuration class. - dim (`int`): - Number of input channels. - drop_path (`float`, *optional*, defaults to 0.0): - Stochastic depth rate. - """ - - def __init__(self, config: ConvNextV2Config, dim: int, drop_path: float = 0.0, **kwargs): - super().__init__(**kwargs) - self.dim = dim - self.config = config - self.dwconv = keras.layers.Conv2D( - filters=dim, - kernel_size=7, - padding="same", - groups=dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="dwconv", - ) # depthwise conv - self.layernorm = keras.layers.LayerNormalization( - epsilon=1e-6, - name="layernorm", - ) - self.pwconv1 = keras.layers.Dense( - units=4 * dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="pwconv1", - ) # pointwise/1x1 convs, implemented with linear layers - self.act = get_tf_activation(config.hidden_act) - self.grn = TFConvNextV2GRN(config, 4 * dim, dtype=tf.float32, name="grn") - self.pwconv2 = keras.layers.Dense( - units=dim, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="pwconv2", - ) - # Using `layers.Activation` instead of `tf.identity` to better control `training` - # behaviour. - self.drop_path = ( - TFConvNextV2DropPath(drop_path, name="drop_path") - if drop_path > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - - def call(self, hidden_states, training=False): - input = hidden_states - x = self.dwconv(hidden_states) - x = self.layernorm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - x = self.drop_path(x, training=training) - x = input + x - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dwconv", None) is not None: - with tf.name_scope(self.dwconv.name): - self.dwconv.build([None, None, None, self.dim]) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, None, self.dim]) - if getattr(self, "pwconv1", None) is not None: - with tf.name_scope(self.pwconv1.name): - self.pwconv1.build([None, None, self.dim]) - if getattr(self, "grn", None) is not None: - with tf.name_scope(self.grn.name): - self.grn.build(None) - if getattr(self, "pwconv2", None) is not None: - with tf.name_scope(self.pwconv2.name): - self.pwconv2.build([None, None, 4 * self.dim]) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - - -# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextStage with ConvNext->ConvNextV2 -class TFConvNextV2Stage(keras.layers.Layer): - """ConvNextV2 stage, consisting of an optional downsampling layer + multiple residual blocks. - - Args: - config (`ConvNextV2V2Config`): - Model configuration class. - in_channels (`int`): - Number of input channels. - out_channels (`int`): - Number of output channels. - depth (`int`): - Number of residual blocks. - drop_path_rates(`list[float]`): - Stochastic depth rates for each layer. - """ - - def __init__( - self, - config: ConvNextV2Config, - in_channels: int, - out_channels: int, - kernel_size: int = 2, - stride: int = 2, - depth: int = 2, - drop_path_rates: list[float] | None = None, - **kwargs, - ): - super().__init__(**kwargs) - if in_channels != out_channels or stride > 1: - self.downsampling_layer = [ - keras.layers.LayerNormalization( - epsilon=1e-6, - name="downsampling_layer.0", - ), - # Inputs to this layer will follow NHWC format since we - # transposed the inputs from NCHW to NHWC in the `TFConvNextV2Embeddings` - # layer. All the outputs throughout the model will be in NHWC - # from this point on until the output where we again change to - # NCHW. - keras.layers.Conv2D( - filters=out_channels, - kernel_size=kernel_size, - strides=stride, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="downsampling_layer.1", - ), - ] - else: - self.downsampling_layer = [tf.identity] - - drop_path_rates = drop_path_rates or [0.0] * depth - self.layers = [ - TFConvNextV2Layer( - config, - dim=out_channels, - drop_path=drop_path_rates[j], - name=f"layers.{j}", - ) - for j in range(depth) - ] - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = stride - - def call(self, hidden_states): - for layer in self.downsampling_layer: - hidden_states = layer(hidden_states) - for layer in self.layers: - hidden_states = layer(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - if self.in_channels != self.out_channels or self.stride > 1: - with tf.name_scope(self.downsampling_layer[0].name): - self.downsampling_layer[0].build([None, None, None, self.in_channels]) - with tf.name_scope(self.downsampling_layer[1].name): - self.downsampling_layer[1].build([None, None, None, self.in_channels]) - - -class TFConvNextV2Encoder(keras.layers.Layer): - def __init__(self, config: ConvNextV2Config, **kwargs): - super().__init__(**kwargs) - self.stages = [] - drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths)) - drop_path_rates = tf.split(drop_path_rates, config.depths) - drop_path_rates = [x.numpy().tolist() for x in drop_path_rates] - prev_chs = config.hidden_sizes[0] - for i in range(config.num_stages): - out_chs = config.hidden_sizes[i] - stage = TFConvNextV2Stage( - config, - in_channels=prev_chs, - out_channels=out_chs, - stride=2 if i > 0 else 1, - depth=config.depths[i], - drop_path_rates=drop_path_rates[i], - name=f"stages.{i}", - ) - self.stages.append(stage) - prev_chs = out_chs - - def call( - self, - hidden_states: tf.Tensor, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - ) -> tuple | TFBaseModelOutputWithNoAttention: - all_hidden_states = () if output_hidden_states else None - - for i, layer_module in enumerate(self.stages): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = layer_module(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - - return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states) - - def build(self, input_shape=None): - for stage in self.stages: - with tf.name_scope(stage.name): - stage.build(None) - - -@keras_serializable -class TFConvNextV2MainLayer(keras.layers.Layer): - config_class = ConvNextV2Config - - def __init__(self, config: ConvNextV2Config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embeddings = TFConvNextV2Embeddings(config, name="embeddings") - self.encoder = TFConvNextV2Encoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - # We are setting the `data_format` like so because from here on we will revert to the - # NCHW output format - self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_last") - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - embedding_output = self.embeddings(pixel_values, training=training) - - encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = encoder_outputs[0] - - # Change to NCHW output format have uniformity in the modules - pooled_output = self.pooler(last_hidden_state) - last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) - pooled_output = self.layernorm(pooled_output) - - # Change the other hidden state outputs to NCHW as well - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) - - if not return_dict: - hidden_states = hidden_states if output_hidden_states else () - return (last_hidden_state, pooled_output) + hidden_states - - return TFBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, self.config.hidden_sizes[-1]]) - - -class TFConvNextV2PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ConvNextV2Config - base_model_prefix = "convnextv2" - main_input_name = "pixel_values" - - -CONVNEXTV2_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`ConvNextV2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CONVNEXTV2_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`ConvNextImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to `True`. -""" - - -@add_start_docstrings( - "The bare ConvNextV2 model outputting raw features without any specific head on top.", - CONVNEXTV2_START_DOCSTRING, -) -class TFConvNextV2Model(TFConvNextV2PreTrainedModel): - def __init__(self, config: ConvNextV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndNoAttention, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndNoAttention | tuple[tf.Tensor]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - outputs = self.convnextv2( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return outputs[:] - - return TFBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=outputs.last_hidden_state, - pooler_output=outputs.pooler_output, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convnextv2", None) is not None: - with tf.name_scope(self.convnextv2.name): - self.convnextv2.build(None) - - -@add_start_docstrings( - """ - ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - CONVNEXTV2_START_DOCSTRING, -) -class TFConvNextV2ForImageClassification(TFConvNextV2PreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: ConvNextV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2") - - # Classifier head - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer=keras.initializers.Zeros(), - name="classifier", - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFImageClassifierOutputWithNoAttention, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFImageClassifierOutputWithNoAttention | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - outputs = self.convnextv2( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutputWithNoAttention( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convnextv2", None) is not None: - with tf.name_scope(self.convnextv2.name): - self.convnextv2.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_sizes[-1]]) - - -__all__ = ["TFConvNextV2ForImageClassification", "TFConvNextV2Model", "TFConvNextV2PreTrainedModel"] diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py deleted file mode 100644 index 1dce90147bd8..000000000000 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ /dev/null @@ -1,920 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Salesforce and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 CTRL model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_ctrl import CTRLConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Salesforce/ctrl" -_CONFIG_FOR_DOC = "CTRLConfig" - - -def angle_defn(pos, i, d_model_size): - angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size) - return pos * angle_rates - - -def positional_encoding(position, d_model_size): - # create the sinusoidal pattern for the positional encoding - angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size) - - sines = np.sin(angle_rads[:, 0::2]) - cosines = np.cos(angle_rads[:, 1::2]) - pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1)) - - return pos_encoding - - -def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None): - # calculate attention - matmul_qk = tf.matmul(q, k, transpose_b=True) - - dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype) - scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) - - if mask is not None: - scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype) - - if attention_mask is not None: - # Apply the attention mask - attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype) - scaled_attention_logits = scaled_attention_logits + attention_mask - - attention_weights = stable_softmax(scaled_attention_logits, axis=-1) - - # Mask heads if we want to - if head_mask is not None: - attention_weights = attention_weights * head_mask - - output = tf.matmul(attention_weights, v) - - return output, attention_weights - - -class TFMultiHeadAttention(keras.layers.Layer): - def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs): - super().__init__(**kwargs) - self.num_heads = num_heads - self.d_model_size = d_model_size - self.output_attentions = output_attentions - - self.depth = int(d_model_size / self.num_heads) - - self.Wq = keras.layers.Dense(d_model_size, name="Wq") - self.Wk = keras.layers.Dense(d_model_size, name="Wk") - self.Wv = keras.layers.Dense(d_model_size, name="Wv") - - self.dense = keras.layers.Dense(d_model_size, name="dense") - - def split_into_heads(self, x, batch_size): - x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False): - batch_size = shape_list(q)[0] - - q = self.Wq(q) - k = self.Wk(k) - v = self.Wv(v) - - q = self.split_into_heads(q, batch_size) - k = self.split_into_heads(k, batch_size) - v = self.split_into_heads(v, batch_size) - - if layer_past is not None: - past_key, past_value = tf.unstack(layer_past, axis=0) - k = tf.concat((past_key, k), axis=-2) - v = tf.concat((past_value, v), axis=-2) - - if use_cache: - present = tf.stack((k, v), axis=0) - else: - present = (None,) - - output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) - scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3]) - attn = output[1] - original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size)) - output = self.dense(original_size_attention) - outputs = (output, present) - - if output_attentions: - outputs = outputs + (attn,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "Wq", None) is not None: - with tf.name_scope(self.Wq.name): - self.Wq.build([None, None, self.d_model_size]) - if getattr(self, "Wk", None) is not None: - with tf.name_scope(self.Wk.name): - self.Wk.build([None, None, self.d_model_size]) - if getattr(self, "Wv", None) is not None: - with tf.name_scope(self.Wv.name): - self.Wv.build([None, None, self.d_model_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.d_model_size]) - - -class TFPointWiseFeedForwardLayer(keras.layers.Layer): - def __init__(self, d_model_size, dff, **kwargs): - super().__init__(**kwargs) - - self.dense_0 = keras.layers.Dense(dff, activation="relu", name="0") - self.dense_2 = keras.layers.Dense(d_model_size, name="2") - self.d_model_size = d_model_size - self.dff = dff - - def call(self, inputs, trainable=False): - dense_0_output = self.dense_0(inputs) - dense_2_output = self.dense_2(dense_0_output) - - return dense_2_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense_0", None) is not None: - with tf.name_scope(self.dense_0.name): - self.dense_0.build([None, None, self.d_model_size]) - if getattr(self, "dense_2", None) is not None: - with tf.name_scope(self.dense_2.name): - self.dense_2.build([None, None, self.dff]) - - -class TFEncoderLayer(keras.layers.Layer): - def __init__( - self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs - ): - super().__init__(**kwargs) - - self.output_attentions = output_attentions - - self.multi_head_attention = TFMultiHeadAttention( - d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention" - ) - self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn") - - self.layernorm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1") - self.layernorm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2") - - self.dropout1 = keras.layers.Dropout(rate) - self.dropout2 = keras.layers.Dropout(rate) - self.d_model_size = d_model_size - - def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False): - normed = self.layernorm1(x) - attn_outputs = self.multi_head_attention( - normed, - normed, - normed, - mask, - layer_past, - attention_mask, - head_mask, - use_cache, - output_attentions, - training=training, - ) - attn_output = attn_outputs[0] - attn_output = self.dropout1(attn_output, training=training) - out1 = x + attn_output - - out2 = self.layernorm2(out1) - ffn_output = self.ffn(out2) - ffn_output = self.dropout2(ffn_output, training=training) - out2 = out1 + ffn_output - - outputs = (out2,) + attn_outputs[1:] - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "multi_head_attention", None) is not None: - with tf.name_scope(self.multi_head_attention.name): - self.multi_head_attention.build(None) - if getattr(self, "ffn", None) is not None: - with tf.name_scope(self.ffn.name): - self.ffn.build(None) - if getattr(self, "layernorm1", None) is not None: - with tf.name_scope(self.layernorm1.name): - self.layernorm1.build([None, None, self.d_model_size]) - if getattr(self, "layernorm2", None) is not None: - with tf.name_scope(self.layernorm2.name): - self.layernorm2.build([None, None, self.d_model_size]) - - -@keras_serializable -class TFCTRLMainLayer(keras.layers.Layer): - config_class = CTRLConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.use_cache = config.use_cache - self.return_dict = config.use_return_dict - - self.d_model_size = config.n_embd - self.num_layers = config.n_layer - - self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size) - - self.w = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.n_embd, - embeddings_initializer=get_initializer(config.initializer_range), - name="w", - ) - - self.dropout = keras.layers.Dropout(config.embd_pdrop) - self.h = [ - TFEncoderLayer( - config.n_embd, - config.n_head, - config.dff, - config.resid_pdrop, - config.layer_norm_epsilon, - self.output_attentions, - name=f"h_._{i}", - ) - for i in range(config.n_layer) - ] - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm") - - def get_input_embeddings(self): - return self.w - - def set_input_embeddings(self, new_embeddings): - self.w = new_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPast: - # If using past key value states, only the last tokens - # should be given as an input - if past_key_values is not None: - if input_ids is not None: - input_ids = input_ids[:, -1:] - if inputs_embeds is not None: - inputs_embeds = inputs_embeds[:, -1:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1:] - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_length = 0 - past_key_values = [None] * len(self.h) - else: - past_length = shape_list(past_key_values[0][0])[-2] - if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0) - position_ids = tf.tile(position_ids, [input_shape[0], 1]) - - # Attention mask. - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length)) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - one_cst = tf.constant(1.0) - ten_thousand_cst = tf.constant(-10000.0) - attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) - attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_layers - - if token_type_ids is not None: - token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - token_type_embeds = self.w(token_type_ids) - token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype)) - else: - token_type_embeds = tf.constant(0.0) - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.w.input_dim) - inputs_embeds = self.w(input_ids) - seq_len = input_shape[-1] - mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) - - inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, inputs_embeds.dtype)) - - pos_embeds = tf.gather(self.pos_encoding, position_ids) - pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype) - hidden_states = inputs_embeds + pos_embeds + token_type_embeds - - hidden_states = self.dropout(hidden_states, training=training) - - output_shape = input_shape + [shape_list(hidden_states)[-1]] - presents = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - outputs = h( - hidden_states, - mask, - layer_past, - attention_mask, - head_mask[i], - use_cache, - output_attentions, - training=training, - ) - hidden_states, present = outputs[:2] - - if use_cache: - presents = presents + (present,) - - if output_attentions: - all_attentions = all_attentions + (outputs[2],) - - hidden_states = self.layernorm(hidden_states) - hidden_states = tf.reshape(hidden_states, output_shape) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if output_attentions: - # let the number of heads free (-1) so we can extract attention even after head pruning - attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] - all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "w", None) is not None: - with tf.name_scope(self.w.name): - self.w.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.n_embd]) - if getattr(self, "h", None) is not None: - for layer in self.h: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFCTRLPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = CTRLConfig - base_model_prefix = "transformer" - - -CTRL_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`CTRLConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CTRL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of - input past key value states). - - Indices of input sequence tokens in the vocabulary. - - If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - past (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see - `past` output below). Can be used to speed up sequential decoding. The token ids which have their past - given to this model should not be passed as input ids as they have already been computed. - attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past` key value states are returned and can be used to speed up decoding (see `past`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.", - CTRL_START_DOCSTRING, -) -class TFCTRLModel(TFCTRLPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFCTRLMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPast: - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -class TFCTRLBiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - self.shape = shape - self.initializer = initializer - self.trainable = trainable - - def build(self, input_shape): - self.bias = self.add_weight( - name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable - ) - super().build(input_shape) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - """ - The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - CTRL_START_DOCSTRING, -) -class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFCTRLMainLayer(config, name="transformer") - self.bias_layer = TFCTRLBiasLayer( - name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True - ) - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"lm_head.bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["lm_head.bias"].shape[-1] - self.bias_layer = TFCTRLBiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True - ) - self.bias_layer.build(None) - self.bias_layer.bias.assign(value["lm_head.bias"]) - - # Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids") - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - inputs = tf.expand_dims(inputs[:, -1], -1) - if token_type_ids is not None: - token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) - - position_ids = kwargs.get("position_ids") - attention_mask = kwargs.get("attention_mask") - - if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) - if past_key_values: - position_ids = tf.expand_dims(position_ids[:, -1], -1) - - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "token_type_ids": token_type_ids, - } - - @unpack_inputs - @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFCausalLMOutputWithPast: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True) - logits = self.bias_layer(logits) - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -@add_start_docstrings( - """ - The CTRL Model transformer with a sequence classification head on top (linear layer). - - [`TFCTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-1, GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - CTRL_START_DOCSTRING, -) -class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.classifier = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - use_bias=False, - ) - self.transformer = TFCTRLMainLayer(config, name="transformer") - self.config = config - - def get_output_embeddings(self): - # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too. - logger.warning( - "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed " - "in transformers v4.32." - ) - return self.transformer.w - - @unpack_inputs - @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFSequenceClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = self.classifier(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=pooled_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.n_embd]) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -__all__ = ["TFCTRLForSequenceClassification", "TFCTRLLMHeadModel", "TFCTRLModel", "TFCTRLPreTrainedModel"] diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py deleted file mode 100644 index 9239e1918eec..000000000000 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ /dev/null @@ -1,1095 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Cvt model.""" - -from __future__ import annotations - -import collections.abc -from dataclasses import dataclass - -import tensorflow as tf - -from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_cvt import CvtConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "CvtConfig" - - -@dataclass -class TFBaseModelOutputWithCLSToken(ModelOutput): - """ - Base class for model's outputs. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`): - Classification token at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - """ - - last_hidden_state: tf.Tensor | None = None - cls_token_value: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - - -class TFCvtDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_prob: float, **kwargs): - super().__init__(**kwargs) - self.drop_prob = drop_prob - - def call(self, x: tf.Tensor, training=None): - if self.drop_prob == 0.0 or not training: - return x - keep_prob = 1 - self.drop_prob - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - - -class TFCvtEmbeddings(keras.layers.Layer): - """Construct the Convolutional Token Embeddings.""" - - def __init__( - self, - config: CvtConfig, - patch_size: int, - num_channels: int, - embed_dim: int, - stride: int, - padding: int, - dropout_rate: float, - **kwargs, - ): - super().__init__(**kwargs) - self.convolution_embeddings = TFCvtConvEmbeddings( - config, - patch_size=patch_size, - num_channels=num_channels, - embed_dim=embed_dim, - stride=stride, - padding=padding, - name="convolution_embeddings", - ) - self.dropout = keras.layers.Dropout(dropout_rate) - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution_embeddings(pixel_values) - hidden_state = self.dropout(hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution_embeddings", None) is not None: - with tf.name_scope(self.convolution_embeddings.name): - self.convolution_embeddings.build(None) - - -class TFCvtConvEmbeddings(keras.layers.Layer): - """Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts.""" - - def __init__( - self, - config: CvtConfig, - patch_size: int, - num_channels: int, - embed_dim: int, - stride: int, - padding: int, - **kwargs, - ): - super().__init__(**kwargs) - self.padding = keras.layers.ZeroPadding2D(padding=padding) - self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - self.projection = keras.layers.Conv2D( - filters=embed_dim, - kernel_size=patch_size, - strides=stride, - padding="valid", - data_format="channels_last", - kernel_initializer=get_initializer(config.initializer_range), - name="projection", - ) - # Using the same default epsilon as PyTorch - self.normalization = keras.layers.LayerNormalization(epsilon=1e-5, name="normalization") - self.num_channels = num_channels - self.embed_dim = embed_dim - - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: - if isinstance(pixel_values, dict): - pixel_values = pixel_values["pixel_values"] - - pixel_values = self.projection(self.padding(pixel_values)) - - # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" - batch_size, height, width, num_channels = shape_list(pixel_values) - hidden_size = height * width - pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels)) - pixel_values = self.normalization(pixel_values) - - # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" - pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels)) - return pixel_values - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, self.embed_dim]) - - -class TFCvtSelfAttentionConvProjection(keras.layers.Layer): - """Convolutional projection layer.""" - - def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs): - super().__init__(**kwargs) - self.padding = keras.layers.ZeroPadding2D(padding=padding) - self.convolution = keras.layers.Conv2D( - filters=embed_dim, - kernel_size=kernel_size, - kernel_initializer=get_initializer(config.initializer_range), - padding="valid", - strides=stride, - use_bias=False, - name="convolution", - groups=embed_dim, - ) - # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum) - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") - self.embed_dim = embed_dim - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution(self.padding(hidden_state)) - hidden_state = self.normalization(hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution", None) is not None: - with tf.name_scope(self.convolution.name): - self.convolution.build([None, None, None, self.embed_dim]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.embed_dim]) - - -class TFCvtSelfAttentionLinearProjection(keras.layers.Layer): - """Linear projection layer used to flatten tokens into 1D.""" - - def call(self, hidden_state: tf.Tensor) -> tf.Tensor: - # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" - batch_size, height, width, num_channels = shape_list(hidden_state) - hidden_size = height * width - hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) - return hidden_state - - -class TFCvtSelfAttentionProjection(keras.layers.Layer): - """Convolutional Projection for Attention.""" - - def __init__( - self, - config: CvtConfig, - embed_dim: int, - kernel_size: int, - stride: int, - padding: int, - projection_method: str = "dw_bn", - **kwargs, - ): - super().__init__(**kwargs) - if projection_method == "dw_bn": - self.convolution_projection = TFCvtSelfAttentionConvProjection( - config, embed_dim, kernel_size, stride, padding, name="convolution_projection" - ) - self.linear_projection = TFCvtSelfAttentionLinearProjection() - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution_projection(hidden_state, training=training) - hidden_state = self.linear_projection(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution_projection", None) is not None: - with tf.name_scope(self.convolution_projection.name): - self.convolution_projection.build(None) - - -class TFCvtSelfAttention(keras.layers.Layer): - """ - Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for - query, key, and value embeddings. - """ - - def __init__( - self, - config: CvtConfig, - num_heads: int, - embed_dim: int, - kernel_size: int, - stride_q: int, - stride_kv: int, - padding_q: int, - padding_kv: int, - qkv_projection_method: str, - qkv_bias: bool, - attention_drop_rate: float, - with_cls_token: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.scale = embed_dim**-0.5 - self.with_cls_token = with_cls_token - self.embed_dim = embed_dim - self.num_heads = num_heads - - self.convolution_projection_query = TFCvtSelfAttentionProjection( - config, - embed_dim, - kernel_size, - stride_q, - padding_q, - projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method, - name="convolution_projection_query", - ) - self.convolution_projection_key = TFCvtSelfAttentionProjection( - config, - embed_dim, - kernel_size, - stride_kv, - padding_kv, - projection_method=qkv_projection_method, - name="convolution_projection_key", - ) - self.convolution_projection_value = TFCvtSelfAttentionProjection( - config, - embed_dim, - kernel_size, - stride_kv, - padding_kv, - projection_method=qkv_projection_method, - name="convolution_projection_value", - ) - - self.projection_query = keras.layers.Dense( - units=embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=qkv_bias, - bias_initializer="zeros", - name="projection_query", - ) - self.projection_key = keras.layers.Dense( - units=embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=qkv_bias, - bias_initializer="zeros", - name="projection_key", - ) - self.projection_value = keras.layers.Dense( - units=embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=qkv_bias, - bias_initializer="zeros", - name="projection_value", - ) - self.dropout = keras.layers.Dropout(attention_drop_rate) - - def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor: - batch_size, hidden_size, _ = shape_list(hidden_state) - head_dim = self.embed_dim // self.num_heads - hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim)) - hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3)) - return hidden_state - - def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: - if self.with_cls_token: - cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - - # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" - batch_size, hidden_size, num_channels = shape_list(hidden_state) - hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) - - key = self.convolution_projection_key(hidden_state, training=training) - query = self.convolution_projection_query(hidden_state, training=training) - value = self.convolution_projection_value(hidden_state, training=training) - - if self.with_cls_token: - query = tf.concat((cls_token, query), axis=1) - key = tf.concat((cls_token, key), axis=1) - value = tf.concat((cls_token, value), axis=1) - - head_dim = self.embed_dim // self.num_heads - - query = self.rearrange_for_multi_head_attention(self.projection_query(query)) - key = self.rearrange_for_multi_head_attention(self.projection_key(key)) - value = self.rearrange_for_multi_head_attention(self.projection_value(value)) - - attention_score = tf.matmul(query, key, transpose_b=True) * self.scale - attention_probs = stable_softmax(logits=attention_score, axis=-1) - attention_probs = self.dropout(attention_probs, training=training) - - context = tf.matmul(attention_probs, value) - # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)" - _, _, hidden_size, _ = shape_list(context) - context = tf.transpose(context, perm=(0, 2, 1, 3)) - context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim)) - return context - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution_projection_query", None) is not None: - with tf.name_scope(self.convolution_projection_query.name): - self.convolution_projection_query.build(None) - if getattr(self, "convolution_projection_key", None) is not None: - with tf.name_scope(self.convolution_projection_key.name): - self.convolution_projection_key.build(None) - if getattr(self, "convolution_projection_value", None) is not None: - with tf.name_scope(self.convolution_projection_value.name): - self.convolution_projection_value.build(None) - if getattr(self, "projection_query", None) is not None: - with tf.name_scope(self.projection_query.name): - self.projection_query.build([None, None, self.embed_dim]) - if getattr(self, "projection_key", None) is not None: - with tf.name_scope(self.projection_key.name): - self.projection_key.build([None, None, self.embed_dim]) - if getattr(self, "projection_value", None) is not None: - with tf.name_scope(self.projection_value.name): - self.projection_value.build([None, None, self.embed_dim]) - - -class TFCvtSelfOutput(keras.layers.Layer): - """Output of the Attention layer .""" - - def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(drop_rate) - self.embed_dim = embed_dim - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.dense(inputs=hidden_state) - hidden_state = self.dropout(inputs=hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.embed_dim]) - - -class TFCvtAttention(keras.layers.Layer): - """Attention layer. First chunk of the convolutional transformer block.""" - - def __init__( - self, - config: CvtConfig, - num_heads: int, - embed_dim: int, - kernel_size: int, - stride_q: int, - stride_kv: int, - padding_q: int, - padding_kv: int, - qkv_projection_method: str, - qkv_bias: bool, - attention_drop_rate: float, - drop_rate: float, - with_cls_token: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.attention = TFCvtSelfAttention( - config, - num_heads, - embed_dim, - kernel_size, - stride_q, - stride_kv, - padding_q, - padding_kv, - qkv_projection_method, - qkv_bias, - attention_drop_rate, - with_cls_token, - name="attention", - ) - self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False): - self_output = self.attention(hidden_state, height, width, training=training) - attention_output = self.dense_output(self_output, training=training) - return attention_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFCvtIntermediate(keras.layers.Layer): - """Intermediate dense layer. Second chunk of the convolutional transformer block.""" - - def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - units=int(embed_dim * mlp_ratio), - kernel_initializer=get_initializer(config.initializer_range), - activation="gelu", - name="dense", - ) - self.embed_dim = embed_dim - - def call(self, hidden_state: tf.Tensor) -> tf.Tensor: - hidden_state = self.dense(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.embed_dim]) - - -class TFCvtOutput(keras.layers.Layer): - """ - Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection. - """ - - def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(drop_rate) - self.embed_dim = embed_dim - self.mlp_ratio = mlp_ratio - - def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.dense(inputs=hidden_state) - hidden_state = self.dropout(inputs=hidden_state, training=training) - hidden_state = hidden_state + input_tensor - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)]) - - -class TFCvtLayer(keras.layers.Layer): - """ - Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It - consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the - `Block` class in the original implementation. - """ - - def __init__( - self, - config: CvtConfig, - num_heads: int, - embed_dim: int, - kernel_size: int, - stride_q: int, - stride_kv: int, - padding_q: int, - padding_kv: int, - qkv_projection_method: str, - qkv_bias: bool, - attention_drop_rate: float, - drop_rate: float, - mlp_ratio: float, - drop_path_rate: float, - with_cls_token: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.attention = TFCvtAttention( - config, - num_heads, - embed_dim, - kernel_size, - stride_q, - stride_kv, - padding_q, - padding_kv, - qkv_projection_method, - qkv_bias, - attention_drop_rate, - drop_rate, - with_cls_token, - name="attention", - ) - self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate") - self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output") - # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour. - self.drop_path = ( - TFCvtDropPath(drop_path_rate, name="drop_path") - if drop_path_rate > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - # Using the same default epsilon as PyTorch - self.layernorm_before = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after") - self.embed_dim = embed_dim - - def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: - # in Cvt, layernorm is applied before self-attention - attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training) - attention_output = self.drop_path(attention_output, training=training) - - # first residual connection - hidden_state = attention_output + hidden_state - - # in Cvt, layernorm is also applied after self-attention - layer_output = self.layernorm_after(hidden_state) - layer_output = self.intermediate(layer_output) - - # second residual connection is done here - layer_output = self.dense_output(layer_output, hidden_state) - layer_output = self.drop_path(layer_output, training=training) - return layer_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.embed_dim]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.embed_dim]) - - -class TFCvtStage(keras.layers.Layer): - """ - Cvt stage (encoder block). Each stage has 2 parts : - - (1) A Convolutional Token Embedding layer - - (2) A Convolutional Transformer Block (layer). - The classification token is added only in the last stage. - - Args: - config ([`CvtConfig`]): Model configuration class. - stage (`int`): Stage number. - """ - - def __init__(self, config: CvtConfig, stage: int, **kwargs): - super().__init__(**kwargs) - self.config = config - self.stage = stage - if self.config.cls_token[self.stage]: - self.cls_token = self.add_weight( - shape=(1, 1, self.config.embed_dim[-1]), - initializer=get_initializer(self.config.initializer_range), - trainable=True, - name="cvt.encoder.stages.2.cls_token", - ) - - self.embedding = TFCvtEmbeddings( - self.config, - patch_size=config.patch_sizes[self.stage], - num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1], - stride=config.patch_stride[self.stage], - embed_dim=config.embed_dim[self.stage], - padding=config.patch_padding[self.stage], - dropout_rate=config.drop_rate[self.stage], - name="embedding", - ) - - drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage]) - drop_path_rates = [x.numpy().item() for x in drop_path_rates] - self.layers = [ - TFCvtLayer( - config, - num_heads=config.num_heads[self.stage], - embed_dim=config.embed_dim[self.stage], - kernel_size=config.kernel_qkv[self.stage], - stride_q=config.stride_q[self.stage], - stride_kv=config.stride_kv[self.stage], - padding_q=config.padding_q[self.stage], - padding_kv=config.padding_kv[self.stage], - qkv_projection_method=config.qkv_projection_method[self.stage], - qkv_bias=config.qkv_bias[self.stage], - attention_drop_rate=config.attention_drop_rate[self.stage], - drop_rate=config.drop_rate[self.stage], - mlp_ratio=config.mlp_ratio[self.stage], - drop_path_rate=drop_path_rates[self.stage], - with_cls_token=config.cls_token[self.stage], - name=f"layers.{j}", - ) - for j in range(config.depth[self.stage]) - ] - - def call(self, hidden_state: tf.Tensor, training: bool = False): - cls_token = None - hidden_state = self.embedding(hidden_state, training) - - # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels" - batch_size, height, width, num_channels = shape_list(hidden_state) - hidden_size = height * width - hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels)) - - if self.config.cls_token[self.stage]: - cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0) - hidden_state = tf.concat((cls_token, hidden_state), axis=1) - - for layer in self.layers: - layer_outputs = layer(hidden_state, height, width, training=training) - hidden_state = layer_outputs - - if self.config.cls_token[self.stage]: - cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1) - - # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels" - hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels)) - return hidden_state, cls_token - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedding", None) is not None: - with tf.name_scope(self.embedding.name): - self.embedding.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFCvtEncoder(keras.layers.Layer): - """ - Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers - (depth) being 1, 2 and 10. - - Args: - config ([`CvtConfig`]): Model configuration class. - """ - - config_class = CvtConfig - - def __init__(self, config: CvtConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.stages = [ - TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth)) - ] - - def call( - self, - pixel_values: TFModelInputType, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool | None = False, - ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - hidden_state = pixel_values - # When running on CPU, `keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width) - # as input format. So change the input format to (batch_size, height, width, num_channels). - hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1)) - - cls_token = None - for _, (stage_module) in enumerate(self.stages): - hidden_state, cls_token = stage_module(hidden_state, training=training) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_state,) - - # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules - hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2)) - if output_hidden_states: - all_hidden_states = tuple(tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states) - - if not return_dict: - return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None) - - return TFBaseModelOutputWithCLSToken( - last_hidden_state=hidden_state, - cls_token_value=cls_token, - hidden_states=all_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "stages", None) is not None: - for layer in self.stages: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFCvtMainLayer(keras.layers.Layer): - """Construct the Cvt model.""" - - config_class = CvtConfig - - def __init__(self, config: CvtConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.encoder = TFCvtEncoder(config, name="encoder") - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - encoder_outputs = self.encoder( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutputWithCLSToken( - last_hidden_state=sequence_output, - cls_token_value=encoder_outputs.cls_token_value, - hidden_states=encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -class TFCvtPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = CvtConfig - base_model_prefix = "cvt" - main_input_name = "pixel_values" - - -TFCVT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TF 2.0 models accepts two formats as inputs: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional arguments. - - This second option is useful when using [`keras.Model.fit`] method which currently requires having all the - tensors in the first argument of the model call function: `model(inputs)`. - - - - Args: - config ([`CvtConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TFCVT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`] - for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.", - TFCVT_START_DOCSTRING, -) -class TFCvtModel(TFCvtPreTrainedModel): - def __init__(self, config: CvtConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.cvt = TFCvtMainLayer(config, name="cvt") - - @unpack_inputs - @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithCLSToken | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFCvtModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13") - >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - outputs = self.cvt( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return (outputs[0],) + outputs[1:] - - return TFBaseModelOutputWithCLSToken( - last_hidden_state=outputs.last_hidden_state, - cls_token_value=outputs.cls_token_value, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "cvt", None) is not None: - with tf.name_scope(self.cvt.name): - self.cvt.build(None) - - -@add_start_docstrings( - """ - Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - """, - TFCVT_START_DOCSTRING, -) -class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: CvtConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.cvt = TFCvtMainLayer(config, name="cvt") - # Using same default epsilon as in the original implementation. - self.layernorm = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm") - - # Classifier head - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=True, - bias_initializer="zeros", - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFImageClassifierOutputWithNoAttention | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFCvtForImageClassification - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13") - >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] - >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) - ```""" - - outputs = self.cvt( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - cls_token = outputs[1] - if self.config.cls_token[-1]: - sequence_output = self.layernorm(cls_token) - else: - # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels" - batch_size, num_channels, height, width = shape_list(sequence_output) - sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width)) - sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1)) - sequence_output = self.layernorm(sequence_output) - - sequence_output_mean = tf.reduce_mean(sequence_output, axis=1) - logits = self.classifier(sequence_output_mean) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "cvt", None) is not None: - with tf.name_scope(self.cvt.name): - self.cvt.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.embed_dim[-1]]) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.embed_dim[-1]]) - - -__all__ = ["TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel"] diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py deleted file mode 100644 index 0fa0fe1f811e..000000000000 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ /dev/null @@ -1,1723 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Data2Vec Vision model.""" - -from __future__ import annotations - -import collections.abc -import math -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFSemanticSegmenterOutput, - TFSequenceClassifierOutput, -) -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_data2vec_vision import Data2VecVisionConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "Data2VecVisionConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base" -_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k" -_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote" - - -@dataclass -class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling): - """ - Class for outputs of [`TFData2VecVisionModel`]. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if - *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token - will be returned. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -class TFData2VecVisionDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_path, **kwargs): - super().__init__(**kwargs) - self.drop_path = drop_path - - def call(self, x, training=None): - if training: - keep_prob = 1 - self.drop_path - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x - - -class TFData2VecVisionEmbeddings(keras.layers.Layer): - """ - Construct the CLS token, position and patch embeddings. Optionally, also the mask token. - - """ - - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings") - self.num_patches = self.patch_embeddings.num_patches - self.config = config - - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - - def build(self, input_shape=None): - self.cls_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), - trainable=True, - name="cls_token", - ) - if self.config.use_mask_token: - self.mask_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), - trainable=True, - name="mask_token", - ) - else: - self.mask_token = None - - if self.config.use_absolute_position_embeddings: - self.position_embeddings = self.add_weight( - shape=(1, self.num_patches + 1, self.config.hidden_size), - initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), - trainable=True, - name="position_embeddings", - ) - else: - self.position_embeddings = None - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - - def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor: - embeddings = self.patch_embeddings(pixel_values) - batch_size, seq_len, projection_dim = shape_list(embeddings) - - cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1)) - - if bool_masked_pos is not None: - mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim)) - # replace the masked visual tokens by mask_tokens - w = bool_masked_pos[..., None] - w = tf.cast(w, mask_tokens.dtype) - # since TF doesn't support eager tensor assignment - embeddings = embeddings * (1 - w) + mask_tokens * w - - embeddings = tf.concat([cls_tokens, embeddings], axis=1) - if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings - embeddings = self.dropout(embeddings) - - return embeddings - - -class TFData2VecVisionPatchEmbeddings(keras.layers.Layer): - """ - Image to Patch Embedding. - """ - - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = num_patches - self.patch_shape = patch_shape - self.num_channels = num_channels - - self.projection = keras.layers.Conv2D( - filters=hidden_size, - kernel_size=patch_size, - strides=patch_size, - padding="valid", - data_format="channels_last", - kernel_initializer="glorot_uniform", # following torch.nn.Linear - bias_initializer="zeros", - name="projection", - ) - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly(): - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the" - " configuration." - ) - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - projection = self.projection(pixel_values) - - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) - - return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -class TFData2VecVisionSelfAttention(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="key", - use_bias=False, - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - if window_size: - self.relative_position_bias = TFData2VecVisionRelativePositionBias( - config, window_size=window_size, name="relative_position_bias" - ) - else: - self.relative_position_bias = None - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - attention_scores = attention_scores / self.sqrt_att_head_size - - # Add relative position bias if present. - if self.relative_position_bias is not None: - # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras - # might complain about `Layer.call()` not being invoked properly. In this case this input - # i.e., 0.0 is not going to be used in any calculations so we're safe. - attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...] - - # Add shared relative position bias if provided. - if relative_position_bias is not None: - attention_scores = attention_scores + relative_position_bias - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "relative_position_bias", None) is not None: - with tf.name_scope(self.relative_position_bias.name): - self.relative_position_bias.build(None) - - -class TFData2VecVisionSelfOutput(keras.layers.Layer): - """ - The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due - to the layernorm applied before each block. - """ - - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFData2VecVisionAttention(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): - super().__init__(**kwargs) - - self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention") - self.dense_output = TFData2VecVisionSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.attention( - hidden_states=input_tensor, - head_mask=head_mask, - output_attentions=output_attentions, - relative_position_bias=relative_position_bias, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision -class TFData2VecVisionIntermediate(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFData2VecVisionOutput(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -class TFData2VecVisionLayer(keras.layers.Layer): - """This corresponds to the Block class in the timm implementation.""" - - def __init__( - self, config: Data2VecVisionConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0, **kwargs - ): - super().__init__(**kwargs) - self.config = config - - self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention") - self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate") - self.data2vec_output = TFData2VecVisionOutput(config, name="output") - - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - # Using `layers.Activation` instead of `tf.identity` to better control `training` - # behaviour. - self.drop_path = ( - TFData2VecVisionDropPath(drop_path_rate, name="drop_path") - if drop_path_rate > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - self.init_values = config.layer_scale_init_value - - def build(self, input_shape: tf.TensorShape = None): - if self.init_values > 0: - self.lambda_1 = self.add_weight( - shape=(self.config.hidden_size), - initializer="ones", - trainable=True, - name="lambda_1", - ) - self.lambda_2 = self.add_weight( - shape=(self.config.hidden_size), - initializer="ones", - trainable=True, - name="lambda_2", - ) - self.lambda_1.assign(self.init_values * tf.ones(self.config.hidden_size)) - self.lambda_2.assign(self.init_values * tf.ones(self.config.hidden_size)) - else: - self.lambda_1, self.lambda_2 = None, None - - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "data2vec_output", None) is not None: - with tf.name_scope(self.data2vec_output.name): - self.data2vec_output.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.config.hidden_size]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.config.hidden_size]) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - relative_position_bias: TFData2VecVisionRelativePositionBias | None = None, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_attention_outputs = self.attention( - # in Data2VecVision, layernorm is applied before self-attention - input_tensor=self.layernorm_before(inputs=hidden_states), - head_mask=head_mask, - output_attentions=output_attentions, - relative_position_bias=relative_position_bias, - training=training, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - # apply lambda_1 if present - if self.lambda_1 is not None: - attention_output = self.lambda_1 * attention_output - - # first residual connection - hidden_states = self.drop_path(attention_output) + hidden_states - - # in Data2VecVision, layernorm is also applied after self-attention - layer_output = self.layernorm_after(hidden_states) - - layer_output = self.intermediate(layer_output) - layer_output = self.data2vec_output(layer_output) - - if self.lambda_2 is not None: - layer_output = self.lambda_2 * layer_output - - # second residual connection - layer_output = self.drop_path(layer_output) + hidden_states - - outputs = (layer_output,) + outputs - - return outputs - - -# Taken and modified from here: -# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28 -class TFData2VecVisionRelativePositionBias(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - - self.window_size = window_size - # +3 for cls_token_pos_len - # window_size can be something like (14, 14) - self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - - self.relative_position_index = self.get_position_index() - - def build(self, input_shape): - self.relative_position_bias_table = self.add_weight( - shape=(self.num_relative_distance, self.config.num_attention_heads), - initializer="zeros", - trainable=True, - name="relative_position_bias_table", - ) # [2*Wh-1 * 2*Ww-1, nH] - # cls to token & token 2 cls & cls to cls - - super().build(input_shape) - - def get_position_index(self): - # get pair-wise relative position index for each token inside the window - xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1])) - coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww] - coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww] - - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Wh*Ww, Wh*Ww] - relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0]) # [Wh*Ww, Wh*Ww, 2] - - xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1) - yy = relative_coords[:, :, 1] + self.window_size[1] - 1 - relative_coords = tf.stack([xx, yy], axis=-1) - - relative_position_index = tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww] - - top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * ( - self.num_relative_distance - 3 - ) - left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * ( - self.num_relative_distance - 2 - ) - corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1) - - left_corner = tf.concat([corner, left], axis=0) - relative_position_index = tf.concat([top, relative_position_index], axis=0) - relative_position_index = tf.concat([left_corner, relative_position_index], axis=1) # [Wh*Ww + 1, Wh*Ww + 1] - return relative_position_index - - def call(self, inputs=None) -> tf.Tensor: - relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0) - return tf.transpose(relative_position_bias, [2, 0, 1]) - - -class TFData2VecVisionEncoder(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - if config.use_shared_relative_position_bias: - self.relative_position_bias = TFData2VecVisionRelativePositionBias( - config, window_size=window_size, name="relative_position_bias" - ) - else: - self.relative_position_bias = None - - # stochastic depth decay rule - dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers)) - self.layer = [ - TFData2VecVisionLayer( - config, - window_size=window_size if config.use_relative_position_bias else None, - drop_path_rate=dpr[i], - name=f"layer_._{i}", - ) - for i in range(config.num_hidden_layers) - ] - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> tuple | TFBaseModelOutput: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras - # might complain about `Layer.call()` not being invoked properly. In this case this input - # i.e., 0.0 is not going to be used in any calculations so we're safe. - relative_position_bias = ( - self.relative_position_bias(0.0) if self.relative_position_bias is not None else None - ) - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "relative_position_bias", None) is not None: - with tf.name_scope(self.relative_position_bias.name): - self.relative_position_bias.build(None) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFData2VecVisionMainLayer(keras.layers.Layer): - config_class = Data2VecVisionConfig - - def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.add_pooling_layer = add_pooling_layer - - self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings") - self.encoder = TFData2VecVisionEncoder( - config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder" - ) - self.layernorm = ( - tf.identity - if config.use_mean_pooling - else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - ) - - # We are setting the `data_format` like so because from here on we will revert to the - # NCHW output format - self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFData2VecVisionModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training) - - encoder_outputs = self.encoder( - embedding_output, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return TFData2VecVisionModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - if hasattr(self.layernorm, "name"): - with tf.name_scope(self.layernorm.name): - self.layernorm.build((None, self.config.hidden_size)) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFData2VecVisionPooler(keras.layers.Layer): - def __init__(self, config: Data2VecVisionConfig, **kwargs): - super().__init__(**kwargs) - self.layernorm = ( - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - if config.use_mean_pooling - else None - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - if self.layernorm is not None: - # Mean pool the final hidden states of the patch tokens - patch_tokens = hidden_states[:, 1:, :] - pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, axis=1)) - else: - # Pool by simply taking the final hidden state of the [CLS] token - pooled_output = hidden_states[:, 0] - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layernorm", None) is not None: - if hasattr(self.layernorm, "name"): - with tf.name_scope(self.layernorm.name): - self.layernorm.build((None, self.config.hidden_size)) - - -class TFData2VecVisionPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = Data2VecVisionConfig - base_model_prefix = "data2vec_vision" - main_input_name = "pixel_values" - _keys_to_ignore_on_load_unexpected = [r"relative_position_index"] - - -DATA2VEC_VISION_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.). - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DATA2VEC_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`BeitImageProcessor.__call__`] for details. - - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.", - DATA2VEC_VISION_START_DOCSTRING, -) -class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel): - def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - - self.data2vec_vision = TFData2VecVisionMainLayer( - config, add_pooling_layer=add_pooling_layer, name="data2vec_vision" - ) - - def get_input_embeddings(self): - return self.data2vec_vision.get_input_embeddings() - - @unpack_inputs - @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFData2VecVisionModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFData2VecVisionModelOutputWithPooling: - r""" - bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ - outputs = self.data2vec_vision( - pixel_values=pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "data2vec_vision", None) is not None: - with tf.name_scope(self.data2vec_vision.name): - self.data2vec_vision.build(None) - - -@add_start_docstrings( - """ - Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of - the final hidden states of the patch tokens) e.g. for ImageNet. - """, - DATA2VEC_VISION_START_DOCSTRING, -) -class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision") - - # Classifier head - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.data2vec_vision( - pixel_values=pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - logits = self.classifier(pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "data2vec_vision", None) is not None: - with tf.name_scope(self.data2vec_vision.name): - self.data2vec_vision.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -class TFData2VecVisionConvModule(keras.layers.Layer): - """ - A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution - layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). - - Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int | tuple[int, int], - padding: str = "valid", - bias: bool = False, - dilation: int | tuple[int, int] = 1, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.conv = keras.layers.Conv2D( - filters=out_channels, - kernel_size=kernel_size, - padding=padding, - use_bias=bias, - dilation_rate=dilation, - name="conv", - ) - self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5) - self.activation = tf.nn.relu - self.in_channels = in_channels - self.out_channels = out_channels - - def call(self, input: tf.Tensor) -> tf.Tensor: - output = self.conv(input) - output = self.bn(output) - output = self.activation(output) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, None, self.in_channels]) - if getattr(self, "bn", None) is not None: - with tf.name_scope(self.bn.name): - self.bn.build((None, None, None, self.out_channels)) - - -class TFAdaptiveAvgPool2D(keras.layers.Layer): - def __init__(self, output_dims: tuple[int, int], input_ordering: str = "NHWC", **kwargs): - super().__init__(**kwargs) - self.output_dims = output_dims - self.input_ordering = input_ordering - if input_ordering not in ("NCHW", "NHWC"): - raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!") - self.h_axis = input_ordering.index("H") - self.w_axis = input_ordering.index("W") - - def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool): - # Figure out which axis we're pooling on - if h_pooling: - axis = self.h_axis - output_dim = self.output_dims[0] - else: - axis = self.w_axis - output_dim = self.output_dims[1] - input_dim = inputs.shape[axis] - - # Figure out the potential pooling windows - # This is the key idea - the torch op always uses only two - # consecutive pooling window sizes, like 3 and 4. Therefore, - # if we pool with both possible sizes, we simply need to gather - # the 'correct' pool at each position to reimplement the torch op. - small_window = math.ceil(input_dim / output_dim) - big_window = small_window + 1 - if h_pooling: - output_dim = self.output_dims[0] - small_window_shape = (small_window, 1) - big_window_shape = (big_window, 1) - else: - output_dim = self.output_dims[1] - small_window_shape = (1, small_window) - big_window_shape = (1, big_window) - - # For resizes to 1, or integer resizes, we can take quick shortcuts - if output_dim == input_dim: - return inputs - elif output_dim == 1: - return tf.reduce_mean(inputs, axis=axis, keepdims=True) - elif input_dim % output_dim == 0: - return tf.nn.avg_pool2d( - inputs, - ksize=small_window_shape, - strides=small_window_shape, - padding="VALID", - data_format=self.input_ordering, - ) - # When upscaling by an integer factor we can also take a quick shortcut - elif output_dim > input_dim and output_dim % input_dim == 0: - return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis) - - # For non-integer resizes, we pool with both possible window sizes and concatenate them - if output_dim < input_dim: - small_pool = tf.nn.avg_pool2d( - inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering - ) - big_pool = tf.nn.avg_pool2d( - inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering - ) - both_pool = tf.concat([small_pool, big_pool], axis=axis) - else: - # When we're actually upscaling instead, then we build the pools a bit differently - small_pool = inputs - big_pool = tf.nn.avg_pool2d( - inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering - ) - both_pool = tf.concat([small_pool, big_pool], axis=axis) - - # We compute vectors of the start and end positions for each pooling window - # Each (start, end) pair here corresponds to a single output position - window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim) - window_starts = tf.cast(window_starts, tf.int64) - window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim) - window_ends = tf.cast(window_ends, tf.int64) - - # pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position - # has a big receptive field and 0 indicates that that output position has a small receptive field - pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool) - - # Since we concatenated the small and big pools, we need to do a bit of - # pointer arithmetic to get the indices of the big pools - small_indices = window_starts - big_indices = window_starts + small_pool.shape[axis] - - # Finally, we use the pool_selector to generate a list of indices, one per output position - gather_indices = tf.where(pool_selector, big_indices, small_indices) - - # Gathering from those indices yields the final, correct pooling - return tf.gather(both_pool, gather_indices, axis=axis) - - def call(self, inputs: tf.Tensor): - if self.input_ordering == "NHWC": - input_shape = inputs.shape[1:3] - else: - input_shape = inputs.shape[2:] - - # We break the task down into each possible case - # Firstly, if we're resizing down to 1, it's just tf.reduce_mean - if self.output_dims[0] == self.output_dims[1] == 1: - if self.input_ordering == "NHWC": - reduce_dims = [1, 2] - else: - reduce_dims = [2, 3] - return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True) - # Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut - elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0: - h_resize = int(input_shape[0] // self.output_dims[0]) - w_resize = int(input_shape[1] // self.output_dims[1]) - return tf.nn.avg_pool2d( - inputs, - ksize=(h_resize, w_resize), - strides=(h_resize, w_resize), - padding="VALID", - data_format=self.input_ordering, - ) - else: - # Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut - # for dimensions where an integer resize is possible. It can also handle upscaling. - h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True) - return self.pseudo_1d_pool(h_pooled, h_pooling=False) - - -class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer): - """ - Pyramid Pooling Module (PPM) used in PSPNet. - - Args: - pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid - Module. - channels (int): Channels after modules, before conv_seg. - - Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. - """ - - def __init__(self, pool_scales: tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None: - super().__init__(**kwargs) - self.pool_scales = pool_scales - self.in_channels = in_channels - self.out_channels = out_channels - - self.layer_list = [] - for idx, pool_scale in enumerate(pool_scales): - pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale) - self.layer_list.append( - [ - TFAdaptiveAvgPool2D(output_dims=pool_scale), - TFData2VecVisionConvModule( - in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1" - ), - ] - ) - - def call(self, x: tf.Tensor) -> list[tf.Tensor]: - ppm_outs = [] - inputs = x - - for ppm in self.layer_list: - for layer_module in ppm: - ppm_out = layer_module(x) - x = ppm_out - - upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear") - ppm_outs.append(upsampled_ppm_out) - return ppm_outs - - def build(self, input_shape=None): - for layer in self.layer_list: - for layer_module in layer: - with tf.name_scope(layer_module.name): - layer_module.build(None) - - -class TFData2VecVisionUperHead(keras.layers.Layer): - """ - Unified Perceptual Parsing for Scene Understanding. This head is the implementation of - [UPerNet](https://huggingface.co/papers/1807.10221). - - Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. - """ - - def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None: - super().__init__(**kwargs) - - self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) - self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] - self.channels = config.hidden_size - self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier") - - # PSP Module - self.psp_modules = TFData2VecVisionPyramidPoolingModule( - self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules" - ) - self.bottleneck = TFData2VecVisionConvModule( - self.in_channels[-1] + len(self.pool_scales) * self.channels, - self.channels, - kernel_size=3, - padding="same", - name="bottleneck", - ) - # FPN Module - self.lateral_convs = [] - self.fpn_convs = [] - for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer - l_conv = TFData2VecVisionConvModule( - in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}" - ) - fpn_conv = TFData2VecVisionConvModule( - in_channels=self.channels, - out_channels=self.channels, - kernel_size=3, - padding="same", - name=f"fpn_convs.{idx}", - ) - self.lateral_convs.append(l_conv) - self.fpn_convs.append(fpn_conv) - - self.fpn_bottleneck = TFData2VecVisionConvModule( - in_channels=len(self.in_channels) * self.channels, - out_channels=self.channels, - kernel_size=3, - padding="same", - name="fpn_bottleneck", - ) - - def psp_forward(self, inputs): - x = inputs[-1] - psp_outs = [x] - psp_outs.extend(self.psp_modules(x)) - psp_outs = tf.concat(psp_outs, axis=-1) - output = self.bottleneck(psp_outs) - - return output - - def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: - # build laterals - laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] - - laterals.append(self.psp_forward(encoder_hidden_states)) - - # build top-down path - used_backbone_levels = len(laterals) - for i in range(used_backbone_levels - 1, 0, -1): - prev_shape = shape_list(laterals[i - 1])[1:-1] - laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear") - - # build outputs - fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] - # append psp feature - fpn_outs.append(laterals[-1]) - - for i in range(used_backbone_levels - 1, 0, -1): - fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear") - fpn_outs = tf.concat(fpn_outs, axis=-1) - output = self.fpn_bottleneck(fpn_outs) - output = self.classifier(output) - - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, None, self.channels]) - if getattr(self, "psp_modules", None) is not None: - with tf.name_scope(self.psp_modules.name): - self.psp_modules.build(None) - if getattr(self, "bottleneck", None) is not None: - with tf.name_scope(self.bottleneck.name): - self.bottleneck.build(None) - if getattr(self, "fpn_bottleneck", None) is not None: - with tf.name_scope(self.fpn_bottleneck.name): - self.fpn_bottleneck.build(None) - for layer in self.lateral_convs: - with tf.name_scope(layer.name): - layer.build(None) - for layer in self.fpn_convs: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFData2VecVisionFCNHead(keras.layers.Layer): - """ - Fully Convolution Networks for Semantic Segmentation. This head is implemented from - [FCNNet](https://huggingface.co/papers/1411.4038). - - Args: - config (Data2VecVisionConfig): Configuration. - kernel_size (int): The kernel size for convs in the head. Default: 3. - dilation (int): The dilation rate for convs in the head. Default: 1. - - - Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. - """ - - def __init__( - self, - config: Data2VecVisionConfig, - in_index: int = 2, - kernel_size: int = 3, - dilation: int | tuple[int, int] = 1, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.in_channels = config.hidden_size - self.channels = config.auxiliary_channels - self.num_convs = config.auxiliary_num_convs - self.concat_input = config.auxiliary_concat_input - self.in_index = in_index - - convs = [] - convs.append( - TFData2VecVisionConvModule( - in_channels=self.in_channels, - out_channels=self.channels, - kernel_size=kernel_size, - padding="same", - dilation=dilation, - name="convs.0", - ) - ) - for i in range(self.num_convs - 1): - convs.append( - TFData2VecVisionConvModule( - in_channels=self.channels, - out_channels=self.channels, - kernel_size=kernel_size, - padding="same", - dilation=dilation, - name=f"conv_module_{i + 2}", - ) - ) - if self.num_convs == 0: - self.convs = [tf.identity] - else: - self.convs = convs - if self.concat_input: - self.conv_cat = TFData2VecVisionConvModule( - self.in_channels + self.channels, - out_channels=self.channels, - kernel_size=kernel_size, - padding="same", - name="conv_cat", - ) - - self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier") - - def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: - # just take the relevant feature maps - hidden_states = encoder_hidden_states[self.in_index] - output = hidden_states - for layer_module in self.convs: - output = layer_module(output) - if self.concat_input: - output = self.conv_cat(tf.concat([hidden_states, output], axis=-1)) - output = self.classifier(output) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, None, self.channels]) - if getattr(self, "conv_cat", None) is not None: - with tf.name_scope(self.conv_cat.name): - self.conv_cat.build(None) - - -@add_start_docstrings( - """ - Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. - """, - DATA2VEC_VISION_START_DOCSTRING, -) -class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel): - def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision") - - # FPNs - self.fpn1 = [ - keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"), - keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5), - keras.layers.Activation("gelu"), - keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"), - ] - self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")] - - self.fpn3 = tf.identity - self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2) - - # Semantic segmentation head(s) - self.decode_head = TFData2VecVisionUperHead(config, name="decode_head") - self.auxiliary_head = ( - TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None - ) - - def compute_loss(self, logits, auxiliary_logits, labels): - # upsample logits to the images' original size - if len(shape_list(labels)) > 3: - label_interp_shape = shape_list(labels)[1:-1] - else: - label_interp_shape = shape_list(labels)[-2:] - - upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") - if auxiliary_logits is not None: - upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear") - # compute weighted loss - loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") - - # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics. - # Utility to mask the index to ignore during computing the loss. - def masked_loss(real, pred): - mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index)) - loss_ = loss_fct(real, pred) - mask = tf.cast(mask, dtype=loss_.dtype) - loss_ *= mask - reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask) - return tf.reshape(reduced_masked_loss, (1,)) - - main_loss = masked_loss(labels, upsampled_logits) - auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits) - loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss - - return loss - - @unpack_inputs - @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | TFSemanticSegmenterOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): - Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base") - >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base") - - >>> inputs = image_processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> # logits are of shape (batch_size, num_labels, height, width) - >>> logits = outputs.logits - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if labels is not None and self.config.num_labels == 1: - raise ValueError("The number of labels should be greater than one") - - outputs = self.data2vec_vision( - pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=True, # we need the intermediate hidden states - return_dict=return_dict, - ) - encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] - - # only keep certain features, and reshape - # note that we do +1 as the encoder_hidden_states also includes the initial embeddings - features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] - patch_resolution = self.config.image_size // self.config.patch_size - - def reshape_features(x): - # We do it this way so TF can always infer the non-batch dims at compile time - x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size)) - return x - - features = [reshape_features(x[:, 1:, :]) for x in features] - - # apply FPNs - ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] - for module in ops[0]: - features[0] = module(features[0]) - features[1] = ops[1][0](features[1]) - for i in range(len(features[2:])): - features[i + 2] = ops[i + 2](features[i + 2]) - - logits = self.decode_head(features) - # Transpose the logits to maintain consistency in the output formats. - transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2]) - - auxiliary_logits = None - if self.auxiliary_head is not None: - auxiliary_logits = self.auxiliary_head(features) - - loss = None - if labels is not None: - loss = self.compute_loss(logits, auxiliary_logits, labels) - - if not return_dict: - if output_hidden_states: - output = (logits,) + outputs[1:] - else: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSemanticSegmenterOutput( - loss=loss, - logits=transposed_logits, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "data2vec_vision", None) is not None: - with tf.name_scope(self.data2vec_vision.name): - self.data2vec_vision.build(None) - if getattr(self, "decode_head", None) is not None: - with tf.name_scope(self.decode_head.name): - self.decode_head.build(None) - if getattr(self, "auxiliary_head", None) is not None: - with tf.name_scope(self.auxiliary_head.name): - self.auxiliary_head.build(None) - if getattr(self, "fpn1", None) is not None: - with tf.name_scope(self.fpn1[0].name): - self.fpn1[0].build([None, None, None, self.config.hidden_size]) - with tf.name_scope(self.fpn1[1].name): - self.fpn1[1].build((None, None, None, self.config.hidden_size)) - with tf.name_scope(self.fpn1[3].name): - self.fpn1[3].build([None, None, None, self.config.hidden_size]) - if getattr(self, "fpn2", None) is not None: - with tf.name_scope(self.fpn2[0].name): - self.fpn2[0].build([None, None, None, self.config.hidden_size]) - - -__all__ = [ - "TFData2VecVisionForImageClassification", - "TFData2VecVisionForSemanticSegmentation", - "TFData2VecVisionModel", - "TFData2VecVisionPreTrainedModel", -] diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py deleted file mode 100644 index 40d23fc28b94..000000000000 --- a/src/transformers/models/deberta/modeling_tf_deberta.py +++ /dev/null @@ -1,1652 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 DeBERTa model.""" - -from __future__ import annotations - -import math -from collections.abc import Sequence - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMaskedLMOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_deberta import DebertaConfig - - -logger = logging.get_logger(__name__) - - -_CONFIG_FOR_DOC = "DebertaConfig" -_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-base" - - -class TFDebertaContextPooler(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense") - self.dropout = TFDebertaStableDropout(config.pooler_dropout, name="dropout") - self.config = config - - def call(self, hidden_states, training: bool = False): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - context_token = hidden_states[:, 0] - context_token = self.dropout(context_token, training=training) - pooled_output = self.dense(context_token) - pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output) - return pooled_output - - @property - def output_dim(self) -> int: - return self.config.hidden_size - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.pooler_hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -class TFDebertaXSoftmax(keras.layers.Layer): - """ - Masked Softmax which is optimized for saving memory - - Args: - input (`tf.Tensor`): The input tensor that will apply softmax. - mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. - dim (int): The dimension that will apply softmax - """ - - def __init__(self, axis=-1, **kwargs): - super().__init__(**kwargs) - self.axis = axis - - def call(self, inputs: tf.Tensor, mask: tf.Tensor): - rmask = tf.logical_not(tf.cast(mask, tf.bool)) - output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs) - output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis) - output = tf.where(rmask, 0.0, output) - return output - - -class TFDebertaStableDropout(keras.layers.Layer): - """ - Optimized dropout module for stabilizing the training - - Args: - drop_prob (float): the dropout probabilities - """ - - def __init__(self, drop_prob, **kwargs): - super().__init__(**kwargs) - self.drop_prob = drop_prob - - @tf.custom_gradient - def xdropout(self, inputs): - """ - Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob. - """ - mask = tf.cast( - 1 - - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)), - tf.bool, - ) - scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype) - if self.drop_prob > 0: - inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale - - def grad(upstream): - if self.drop_prob > 0: - return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale - else: - return upstream - - return inputs, grad - - def call(self, inputs: tf.Tensor, training: tf.Tensor = False): - if training: - return self.xdropout(inputs) - return inputs - - -class TFDebertaLayerNorm(keras.layers.Layer): - """LayerNorm module in the TF style (epsilon inside the square root).""" - - def __init__(self, size, eps=1e-12, **kwargs): - super().__init__(**kwargs) - self.size = size - self.eps = eps - - def build(self, input_shape): - self.gamma = self.add_weight(shape=[self.size], initializer=tf.ones_initializer(), name="weight") - self.beta = self.add_weight(shape=[self.size], initializer=tf.zeros_initializer(), name="bias") - return super().build(input_shape) - - def call(self, x: tf.Tensor) -> tf.Tensor: - mean = tf.reduce_mean(x, axis=[-1], keepdims=True) - variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True) - std = tf.math.sqrt(variance + self.eps) - return self.gamma * (x - mean) / std + self.beta - - -class TFDebertaSelfOutput(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.hidden_size, name="dense") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def call(self, hidden_states, input_tensor, training: bool = False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -class TFDebertaAttention(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - self.self = TFDebertaDisentangledSelfAttention(config, name="self") - self.dense_output = TFDebertaSelfOutput(config, name="output") - self.config = config - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self( - hidden_states=input_tensor, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - if query_states is None: - query_states = input_tensor - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=query_states, training=training - ) - - output = (attention_output,) + self_outputs[1:] - - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFDebertaIntermediate(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFDebertaOutput(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -class TFDebertaLayer(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFDebertaAttention(config, name="attention") - self.intermediate = TFDebertaIntermediate(config, name="intermediate") - self.bert_output = TFDebertaOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - - -class TFDebertaEncoder(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFDebertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - self.relative_attention = getattr(config, "relative_attention", False) - self.config = config - if self.relative_attention: - self.max_relative_positions = getattr(config, "max_relative_positions", -1) - if self.max_relative_positions < 1: - self.max_relative_positions = config.max_position_embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.relative_attention: - self.rel_embeddings = self.add_weight( - name="rel_embeddings.weight", - shape=[self.max_relative_positions * 2, self.config.hidden_size], - initializer=get_initializer(self.config.initializer_range), - ) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - def get_rel_embedding(self): - rel_embeddings = self.rel_embeddings if self.relative_attention else None - return rel_embeddings - - def get_attention_mask(self, attention_mask): - if len(shape_list(attention_mask)) <= 2: - extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2) - attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1) - attention_mask = tf.cast(attention_mask, tf.uint8) - elif len(shape_list(attention_mask)) == 3: - attention_mask = tf.expand_dims(attention_mask, 1) - - return attention_mask - - def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): - if self.relative_attention and relative_pos is None: - q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2] - relative_pos = build_relative_position(q, shape_list(hidden_states)[-2]) - return relative_pos - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - attention_mask = self.get_attention_mask(attention_mask) - relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) - - if isinstance(hidden_states, Sequence): - next_kv = hidden_states[0] - else: - next_kv = hidden_states - - rel_embeddings = self.get_rel_embedding() - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=next_kv, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if query_states is not None: - query_states = hidden_states - if isinstance(hidden_states, Sequence): - next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None - else: - next_kv = hidden_states - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -def build_relative_position(query_size, key_size): - """ - Build relative position according to the query and key - - We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key - \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - - P_k\\) - - Args: - query_size (int): the length of query - key_size (int): the length of key - - Return: - `tf.Tensor`: A tensor with shape [1, query_size, key_size] - - """ - q_ids = tf.range(query_size, dtype=tf.int32) - k_ids = tf.range(key_size, dtype=tf.int32) - rel_pos_ids = q_ids[:, None] - tf.tile(tf.reshape(k_ids, [1, -1]), [query_size, 1]) - rel_pos_ids = rel_pos_ids[:query_size, :] - rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0) - return tf.cast(rel_pos_ids, tf.int64) - - -def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): - shapes = [ - shape_list(query_layer)[0], - shape_list(query_layer)[1], - shape_list(query_layer)[2], - shape_list(relative_pos)[-1], - ] - return tf.broadcast_to(c2p_pos, shapes) - - -def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): - shapes = [ - shape_list(query_layer)[0], - shape_list(query_layer)[1], - shape_list(key_layer)[-2], - shape_list(key_layer)[-2], - ] - return tf.broadcast_to(c2p_pos, shapes) - - -def pos_dynamic_expand(pos_index, p2c_att, key_layer): - shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]] - return tf.broadcast_to(pos_index, shapes) - - -def torch_gather(x, indices, gather_axis): - if gather_axis < 0: - gather_axis = tf.rank(x) + gather_axis - - if gather_axis != tf.rank(x) - 1: - pre_roll = tf.rank(x) - 1 - gather_axis - permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0) - x = tf.transpose(x, perm=permutation) - indices = tf.transpose(indices, perm=permutation) - else: - pre_roll = 0 - - flat_x = tf.reshape(x, (-1, tf.shape(x)[-1])) - flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1])) - gathered = tf.gather(flat_x, flat_indices, batch_dims=1) - gathered = tf.reshape(gathered, tf.shape(indices)) - - if pre_roll != 0: - permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0) - gathered = tf.transpose(gathered, perm=permutation) - - return gathered - - -class TFDebertaDisentangledSelfAttention(keras.layers.Layer): - """ - Disentangled self-attention module - - Parameters: - config (`str`): - A model config class instance with the configuration to build a new model. The schema is similar to - *BertConfig*, for more details, please refer [`DebertaConfig`] - - """ - - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.in_proj = keras.layers.Dense( - self.all_head_size * 3, - kernel_initializer=get_initializer(config.initializer_range), - name="in_proj", - use_bias=False, - ) - self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] - - self.relative_attention = getattr(config, "relative_attention", False) - self.talking_head = getattr(config, "talking_head", False) - - if self.talking_head: - self.head_logits_proj = keras.layers.Dense( - self.num_attention_heads, - kernel_initializer=get_initializer(config.initializer_range), - name="head_logits_proj", - use_bias=False, - ) - self.head_weights_proj = keras.layers.Dense( - self.num_attention_heads, - kernel_initializer=get_initializer(config.initializer_range), - name="head_weights_proj", - use_bias=False, - ) - - self.softmax = TFDebertaXSoftmax(axis=-1) - - if self.relative_attention: - self.max_relative_positions = getattr(config, "max_relative_positions", -1) - if self.max_relative_positions < 1: - self.max_relative_positions = config.max_position_embeddings - self.pos_dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="pos_dropout") - if "c2p" in self.pos_att_type: - self.pos_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="pos_proj", - use_bias=False, - ) - if "p2c" in self.pos_att_type: - self.pos_q_proj = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="pos_q_proj" - ) - - self.dropout = TFDebertaStableDropout(config.attention_probs_dropout_prob, name="dropout") - self.config = config - - def build(self, input_shape=None): - if self.built: - return - self.built = True - self.q_bias = self.add_weight( - name="q_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros() - ) - self.v_bias = self.add_weight( - name="v_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros() - ) - if getattr(self, "in_proj", None) is not None: - with tf.name_scope(self.in_proj.name): - self.in_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "head_logits_proj", None) is not None: - with tf.name_scope(self.head_logits_proj.name): - self.head_logits_proj.build(None) - if getattr(self, "head_weights_proj", None) is not None: - with tf.name_scope(self.head_weights_proj.name): - self.head_weights_proj.build(None) - if getattr(self, "pos_dropout", None) is not None: - with tf.name_scope(self.pos_dropout.name): - self.pos_dropout.build(None) - if getattr(self, "pos_proj", None) is not None: - with tf.name_scope(self.pos_proj.name): - self.pos_proj.build([self.config.hidden_size]) - if getattr(self, "pos_q_proj", None) is not None: - with tf.name_scope(self.pos_q_proj.name): - self.pos_q_proj.build([self.config.hidden_size]) - - def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor: - shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1] - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=shape) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - """ - Call the module - - Args: - hidden_states (`tf.Tensor`): - Input states to the module usually the output from previous layer, it will be the Q,K and V in - *Attention(Q,K,V)* - - attention_mask (`tf.Tensor`): - An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum - sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* - th token. - - return_att (`bool`, *optional*): - Whether return the attention matrix. - - query_states (`tf.Tensor`, *optional*): - The *Q* state in *Attention(Q,K,V)*. - - relative_pos (`tf.Tensor`): - The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with - values ranging in [*-max_relative_positions*, *max_relative_positions*]. - - rel_embeddings (`tf.Tensor`): - The embedding of relative distances. It's a tensor of shape [\\(2 \\times - \\text{max_relative_positions}\\), *hidden_size*]. - - - """ - if query_states is None: - qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) - query_layer, key_layer, value_layer = tf.split( - self.transpose_for_scores(qp), num_or_size_splits=3, axis=-1 - ) - else: - - def linear(w, b, x): - out = tf.matmul(x, w, transpose_b=True) - if b is not None: - out += tf.transpose(b) - return out - - ws = tf.split( - tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0 - ) - qkvw = tf.TensorArray(dtype=self.dtype, size=3) - for k in tf.range(3): - qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads) - for i in tf.range(self.num_attention_heads): - qkvw_inside = qkvw_inside.write(i, ws[i * 3 + k]) - qkvw = qkvw.write(k, qkvw_inside.concat()) - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states) - k = linear(qkvw[1], qkvb[1], hidden_states) - v = linear(qkvw[2], qkvb[2], hidden_states) - query_layer = self.transpose_for_scores(q) - key_layer = self.transpose_for_scores(k) - value_layer = self.transpose_for_scores(v) - - query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) - value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) - - rel_att = None - # Take the dot product between "query" and "key" to get the raw attention scores. - scale_factor = 1 + len(self.pos_att_type) - scale = math.sqrt(shape_list(query_layer)[-1] * scale_factor) - query_layer = query_layer / scale - - attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 1, 3, 2])) - if self.relative_attention: - rel_embeddings = self.pos_dropout(rel_embeddings, training=training) - rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) - - if rel_att is not None: - attention_scores = attention_scores + rel_att - - if self.talking_head: - attention_scores = tf.transpose( - self.head_logits_proj(tf.transpose(attention_scores, [0, 2, 3, 1])), [0, 3, 1, 2] - ) - - attention_probs = self.softmax(attention_scores, attention_mask) - attention_probs = self.dropout(attention_probs, training=training) - if self.talking_head: - attention_probs = tf.transpose( - self.head_weights_proj(tf.transpose(attention_probs, [0, 2, 3, 1])), [0, 3, 1, 2] - ) - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) - context_layer_shape = shape_list(context_layer) - # Set the final dimension here explicitly. - # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing - # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput - # requires final input dimension to be defined - new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): - if relative_pos is None: - q = shape_list(query_layer)[-2] - relative_pos = build_relative_position(q, shape_list(key_layer)[-2]) - shape_list_pos = shape_list(relative_pos) - if len(shape_list_pos) == 2: - relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) - elif len(shape_list_pos) == 3: - relative_pos = tf.expand_dims(relative_pos, 1) - # bxhxqxk - elif len(shape_list_pos) != 4: - raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}") - - att_span = tf.cast( - tf.minimum( - tf.maximum(shape_list(query_layer)[-2], shape_list(key_layer)[-2]), self.max_relative_positions - ), - tf.int64, - ) - rel_embeddings = tf.expand_dims( - rel_embeddings[self.max_relative_positions - att_span : self.max_relative_positions + att_span, :], 0 - ) - - score = 0 - - # content->position - if "c2p" in self.pos_att_type: - pos_key_layer = self.pos_proj(rel_embeddings) - pos_key_layer = self.transpose_for_scores(pos_key_layer) - c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 1, 3, 2])) - c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1) - c2p_att = torch_gather(c2p_att, c2p_dynamic_expand(c2p_pos, query_layer, relative_pos), -1) - score += c2p_att - - # position->content - if "p2c" in self.pos_att_type: - pos_query_layer = self.pos_q_proj(rel_embeddings) - pos_query_layer = self.transpose_for_scores(pos_query_layer) - pos_query_layer /= tf.math.sqrt( - tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype) - ) - if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]: - r_pos = build_relative_position(shape_list(key_layer)[-2], shape_list(key_layer)[-2]) - else: - r_pos = relative_pos - p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1) - p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 1, 3, 2])) - p2c_att = tf.transpose( - torch_gather(p2c_att, p2c_dynamic_expand(p2c_pos, query_layer, key_layer), -1), [0, 1, 3, 2] - ) - if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]: - pos_index = tf.expand_dims(relative_pos[:, :, :, 0], -1) - p2c_att = torch_gather(p2c_att, pos_dynamic_expand(pos_index, p2c_att, key_layer), -2) - score += p2c_att - - return score - - -class TFDebertaEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.position_biased_input = getattr(config, "position_biased_input", True) - self.initializer_range = config.initializer_range - if self.embedding_size != config.hidden_size: - self.embed_proj = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="embed_proj", - use_bias=False, - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout") - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - if self.config.type_vocab_size > 0: - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - else: - self.token_type_embeddings = None - - with tf.name_scope("position_embeddings"): - if self.position_biased_input: - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - else: - self.position_embeddings = None - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "embed_proj", None) is not None: - with tf.name_scope(self.embed_proj.name): - self.embed_proj.build([None, None, self.embedding_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - mask: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - final_embeddings = inputs_embeds - if self.position_biased_input: - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - final_embeddings += position_embeds - if self.config.type_vocab_size > 0: - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings += token_type_embeds - - if self.embedding_size != self.hidden_size: - final_embeddings = self.embed_proj(final_embeddings) - - final_embeddings = self.LayerNorm(final_embeddings) - - if mask is not None: - if len(shape_list(mask)) != len(shape_list(final_embeddings)): - if len(shape_list(mask)) == 4: - mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1) - mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype) - - final_embeddings = final_embeddings * mask - - final_embeddings = self.dropout(final_embeddings, training=training) - - return final_embeddings - - -class TFDebertaPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - - self.dense = keras.layers.Dense( - units=self.embedding_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.embedding_size]) - - -class TFDebertaLMPredictionHead(keras.layers.Layer): - def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - - self.transform = TFDebertaPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -class TFDebertaOnlyMLMHead(keras.layers.Layer): - def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - self.predictions = TFDebertaLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -# @keras_serializable -class TFDebertaMainLayer(keras.layers.Layer): - config_class = DebertaConfig - - def __init__(self, config: DebertaConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFDebertaEmbeddings(config, name="embeddings") - self.encoder = TFDebertaEncoder(config, name="encoder") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - mask=attention_mask, - training=training, - ) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -class TFDebertaPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DebertaConfig - base_model_prefix = "deberta" - - -DEBERTA_START_DOCSTRING = r""" - The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled - Attention](https://huggingface.co/papers/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build - on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two - improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`DebertaConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DEBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert *input_ids* indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput``] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", - DEBERTA_START_DOCSTRING, -) -class TFDebertaModel(TFDebertaPreTrainedModel): - def __init__(self, config: DebertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.deberta = TFDebertaMainLayer(config, name="deberta") - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - - -@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) -class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: DebertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFDebertaForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.deberta = TFDebertaMainLayer(config, name="deberta") - self.mlm = TFDebertaOnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """ - DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - DEBERTA_START_DOCSTRING, -) -class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: DebertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaMainLayer(config, name="deberta") - self.pooler = TFDebertaContextPooler(config, name="pooler") - - drop_out = getattr(config, "cls_dropout", None) - drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out - self.dropout = TFDebertaStableDropout(drop_out, name="cls_dropout") - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.output_dim = self.pooler.output_dim - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - pooled_output = self.pooler(sequence_output, training=training) - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.output_dim]) - - -@add_start_docstrings( - """ - DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - DEBERTA_START_DOCSTRING, -) -class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config: DebertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaMainLayer(config, name="deberta") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - DEBERTA_START_DOCSTRING, -) -class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config: DebertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaMainLayer(config, name="deberta") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFDebertaForMaskedLM", - "TFDebertaForQuestionAnswering", - "TFDebertaForSequenceClassification", - "TFDebertaForTokenClassification", - "TFDebertaModel", - "TFDebertaPreTrainedModel", -] diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py deleted file mode 100644 index d71891ac19c0..000000000000 --- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +++ /dev/null @@ -1,1879 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 DeBERTa-v2 model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_deberta_v2 import DebertaV2Config - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "DebertaV2Config" -_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge" - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaContextPooler with Deberta->DebertaV2 -class TFDebertaV2ContextPooler(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense") - self.dropout = TFDebertaV2StableDropout(config.pooler_dropout, name="dropout") - self.config = config - - def call(self, hidden_states, training: bool = False): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - context_token = hidden_states[:, 0] - context_token = self.dropout(context_token, training=training) - pooled_output = self.dense(context_token) - pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output) - return pooled_output - - @property - def output_dim(self) -> int: - return self.config.hidden_size - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.pooler_hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2 -class TFDebertaV2XSoftmax(keras.layers.Layer): - """ - Masked Softmax which is optimized for saving memory - - Args: - input (`tf.Tensor`): The input tensor that will apply softmax. - mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. - dim (int): The dimension that will apply softmax - """ - - def __init__(self, axis=-1, **kwargs): - super().__init__(**kwargs) - self.axis = axis - - def call(self, inputs: tf.Tensor, mask: tf.Tensor): - rmask = tf.logical_not(tf.cast(mask, tf.bool)) - output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs) - output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis) - output = tf.where(rmask, 0.0, output) - return output - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2 -class TFDebertaV2StableDropout(keras.layers.Layer): - """ - Optimized dropout module for stabilizing the training - - Args: - drop_prob (float): the dropout probabilities - """ - - def __init__(self, drop_prob, **kwargs): - super().__init__(**kwargs) - self.drop_prob = drop_prob - - @tf.custom_gradient - def xdropout(self, inputs): - """ - Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob. - """ - mask = tf.cast( - 1 - - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)), - tf.bool, - ) - scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype) - if self.drop_prob > 0: - inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale - - def grad(upstream): - if self.drop_prob > 0: - return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale - else: - return upstream - - return inputs, grad - - def call(self, inputs: tf.Tensor, training: tf.Tensor = False): - if training: - return self.xdropout(inputs) - return inputs - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2 -class TFDebertaV2SelfOutput(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.hidden_size, name="dense") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def call(self, hidden_states, input_tensor, training: bool = False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2 -class TFDebertaV2Attention(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - self.self = TFDebertaV2DisentangledSelfAttention(config, name="self") - self.dense_output = TFDebertaV2SelfOutput(config, name="output") - self.config = config - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self( - hidden_states=input_tensor, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - if query_states is None: - query_states = input_tensor - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=query_states, training=training - ) - - output = (attention_output,) + self_outputs[1:] - - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate with Deberta->DebertaV2 -class TFDebertaV2Intermediate(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2 -class TFDebertaV2Output(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2 -class TFDebertaV2Layer(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.attention = TFDebertaV2Attention(config, name="attention") - self.intermediate = TFDebertaV2Intermediate(config, name="intermediate") - self.bert_output = TFDebertaV2Output(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - - -class TFDebertaV2ConvLayer(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.kernel_size = getattr(config, "conv_kernel_size", 3) - # groups = getattr(config, "conv_groups", 1) - self.conv_act = get_tf_activation(getattr(config, "conv_act", "tanh")) - self.padding = (self.kernel_size - 1) // 2 - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def build(self, input_shape=None): - if self.built: - return - self.built = True - with tf.name_scope("conv"): - self.conv_kernel = self.add_weight( - name="kernel", - shape=[self.kernel_size, self.config.hidden_size, self.config.hidden_size], - initializer=get_initializer(self.config.initializer_range), - ) - self.conv_bias = self.add_weight( - name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer() - ) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - def call( - self, hidden_states: tf.Tensor, residual_states: tf.Tensor, input_mask: tf.Tensor, training: bool = False - ) -> tf.Tensor: - out = tf.nn.conv2d( - tf.expand_dims(hidden_states, 1), - tf.expand_dims(self.conv_kernel, 0), - strides=1, - padding=[[0, 0], [0, 0], [self.padding, self.padding], [0, 0]], - ) - out = tf.squeeze(tf.nn.bias_add(out, self.conv_bias), 1) - rmask = tf.cast(1 - input_mask, tf.bool) - out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out) - out = self.dropout(out, training=training) - out = self.conv_act(out) - - layer_norm_input = residual_states + out - output = self.LayerNorm(layer_norm_input) - - if input_mask is None: - output_states = output - else: - if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)): - if len(shape_list(input_mask)) == 4: - input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1) - input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype) - - output_states = output * input_mask - - return output_states - - -class TFDebertaV2Encoder(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFDebertaV2Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - self.relative_attention = getattr(config, "relative_attention", False) - self.config = config - if self.relative_attention: - self.max_relative_positions = getattr(config, "max_relative_positions", -1) - if self.max_relative_positions < 1: - self.max_relative_positions = config.max_position_embeddings - - self.position_buckets = getattr(config, "position_buckets", -1) - self.pos_ebd_size = self.max_relative_positions * 2 - - if self.position_buckets > 0: - self.pos_ebd_size = self.position_buckets * 2 - - self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] - - if "layer_norm" in self.norm_rel_ebd: - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - - self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.relative_attention: - self.rel_embeddings = self.add_weight( - name="rel_embeddings.weight", - shape=[self.pos_ebd_size, self.config.hidden_size], - initializer=get_initializer(self.config.initializer_range), - ) - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - def get_rel_embedding(self): - rel_embeddings = self.rel_embeddings if self.relative_attention else None - if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): - rel_embeddings = self.LayerNorm(rel_embeddings) - return rel_embeddings - - def get_attention_mask(self, attention_mask): - if len(shape_list(attention_mask)) <= 2: - extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2) - attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1) - attention_mask = tf.cast(attention_mask, tf.uint8) - elif len(shape_list(attention_mask)) == 3: - attention_mask = tf.expand_dims(attention_mask, 1) - - return attention_mask - - def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): - if self.relative_attention and relative_pos is None: - q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2] - relative_pos = build_relative_position( - q, - shape_list(hidden_states)[-2], - bucket_size=self.position_buckets, - max_position=self.max_relative_positions, - ) - return relative_pos - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - if len(shape_list(attention_mask)) <= 2: - input_mask = attention_mask - else: - input_mask = tf.cast(tf.math.reduce_sum(attention_mask, axis=-2) > 0, dtype=tf.uint8) - - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - attention_mask = self.get_attention_mask(attention_mask) - relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) - - next_kv = hidden_states - - rel_embeddings = self.get_rel_embedding() - output_states = next_kv - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (output_states,) - - layer_outputs = layer_module( - hidden_states=next_kv, - attention_mask=attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - training=training, - ) - output_states = layer_outputs[0] - - if i == 0 and self.conv is not None: - output_states = self.conv(hidden_states, output_states, input_mask) - - next_kv = output_states - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (output_states,) - - if not return_dict: - return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -def make_log_bucket_position(relative_pos, bucket_size, max_position): - sign = tf.math.sign(relative_pos) - mid = bucket_size // 2 - abs_pos = tf.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, tf.math.abs(relative_pos)) - log_pos = tf.math.ceil( - tf.cast(tf.math.log(abs_pos / mid), tf.float32) - / tf.cast(tf.math.log((max_position - 1) / mid), tf.float32) - * tf.cast(mid - 1, tf.float32) # in graph mode - ) + tf.cast(mid, tf.float32) - bucket_pos = tf.cast( - tf.where(abs_pos <= mid, tf.cast(relative_pos, tf.float32), log_pos * tf.cast(sign, tf.float32)), tf.int32 - ) - return bucket_pos - - -def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): - """ - Build relative position according to the query and key - - We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key - \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - - P_k\\) - - Args: - query_size (int): the length of query - key_size (int): the length of key - bucket_size (int): the size of position bucket - max_position (int): the maximum allowed absolute position - - Return: - `tf.Tensor`: A tensor with shape [1, query_size, key_size] - - """ - q_ids = tf.range(query_size, dtype=tf.int32) - k_ids = tf.range(key_size, dtype=tf.int32) - rel_pos_ids = q_ids[:, None] - tf.tile(tf.expand_dims(k_ids, axis=0), [shape_list(q_ids)[0], 1]) - if bucket_size > 0 and max_position > 0: - rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = rel_pos_ids[:query_size, :] - rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0) - return tf.cast(rel_pos_ids, tf.int64) - - -def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): - shapes = [ - shape_list(query_layer)[0], - shape_list(query_layer)[1], - shape_list(query_layer)[2], - shape_list(relative_pos)[-1], - ] - return tf.broadcast_to(c2p_pos, shapes) - - -def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): - shapes = [ - shape_list(query_layer)[0], - shape_list(query_layer)[1], - shape_list(key_layer)[-2], - shape_list(key_layer)[-2], - ] - return tf.broadcast_to(c2p_pos, shapes) - - -def pos_dynamic_expand(pos_index, p2c_att, key_layer): - shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]] - return tf.broadcast_to(pos_index, shapes) - - -def take_along_axis(x, indices): - # Only a valid port of np.take_along_axis when the gather axis is -1 - - # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239 - if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy): - # [B, S, P] -> [B, S, P, D] - one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype) - - # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x) - # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P] - gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x) - - # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls - else: - gathered = tf.gather(x, indices, batch_dims=2) - - return gathered - - -class TFDebertaV2DisentangledSelfAttention(keras.layers.Layer): - """ - Disentangled self-attention module - - Parameters: - config (`DebertaV2Config`): - A model config class instance with the configuration to build a new model. The schema is similar to - *BertConfig*, for more details, please refer [`DebertaV2Config`] - - """ - - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - self.num_attention_heads = config.num_attention_heads - _attention_head_size = config.hidden_size // config.num_attention_heads - self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="query_proj", - use_bias=True, - ) - self.key_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="key_proj", - use_bias=True, - ) - self.value_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="value_proj", - use_bias=True, - ) - - self.share_att_key = getattr(config, "share_att_key", False) - self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] - self.relative_attention = getattr(config, "relative_attention", False) - - if self.relative_attention: - self.position_buckets = getattr(config, "position_buckets", -1) - self.max_relative_positions = getattr(config, "max_relative_positions", -1) - if self.max_relative_positions < 1: - self.max_relative_positions = config.max_position_embeddings - self.pos_ebd_size = self.max_relative_positions - if self.position_buckets > 0: - self.pos_ebd_size = self.position_buckets - - self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout") - - if not self.share_att_key: - if "c2p" in self.pos_att_type: - self.pos_key_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="pos_proj", - use_bias=True, - ) - if "p2c" in self.pos_att_type: - self.pos_query_proj = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="pos_q_proj", - ) - self.softmax = TFDebertaV2XSoftmax(axis=-1) - self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout") - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor: - tensor_shape = shape_list(tensor) - # In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None - shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads] - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=shape) - tensor = tf.transpose(tensor, perm=[0, 2, 1, 3]) - x_shape = shape_list(tensor) - tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]]) - return tensor - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - query_states: tf.Tensor | None = None, - relative_pos: tf.Tensor | None = None, - rel_embeddings: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - """ - Call the module - - Args: - hidden_states (`tf.Tensor`): - Input states to the module usually the output from previous layer, it will be the Q,K and V in - *Attention(Q,K,V)* - - attention_mask (`tf.Tensor`): - An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum - sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* - th token. - - return_att (`bool`, *optional*): - Whether return the attention matrix. - - query_states (`tf.Tensor`, *optional*): - The *Q* state in *Attention(Q,K,V)*. - - relative_pos (`tf.Tensor`): - The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with - values ranging in [*-max_relative_positions*, *max_relative_positions*]. - - rel_embeddings (`tf.Tensor`): - The embedding of relative distances. It's a tensor of shape [\\(2 \\times - \\text{max_relative_positions}\\), *hidden_size*]. - - - """ - if query_states is None: - query_states = hidden_states - query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) - key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) - value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) - - rel_att = None - # Take the dot product between "query" and "key" to get the raw attention scores. - scale_factor = 1 - if "c2p" in self.pos_att_type: - scale_factor += 1 - if "p2c" in self.pos_att_type: - scale_factor += 1 - scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype)) - attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale) - if self.relative_attention: - rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) - - if rel_att is not None: - attention_scores = attention_scores + rel_att - attention_scores = tf.reshape( - attention_scores, - (-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]), - ) - - # bsz x height x length x dimension - attention_probs = self.softmax(attention_scores, attention_mask) - attention_probs = self.dropout(attention_probs, training=training) - context_layer = tf.matmul( - tf.reshape(attention_probs, [-1, shape_list(attention_probs)[-2], shape_list(attention_probs)[-1]]), - value_layer, - ) - context_layer = tf.transpose( - tf.reshape( - context_layer, - [-1, self.num_attention_heads, shape_list(context_layer)[-2], shape_list(context_layer)[-1]], - ), - [0, 2, 1, 3], - ) - # Set the final dimension here explicitly. - # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing - # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput - # requires final input dimension to be defined - context_layer_shape = shape_list(context_layer) - new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): - if relative_pos is None: - q = shape_list(query_layer)[-2] - relative_pos = build_relative_position( - q, - shape_list(key_layer)[-2], - bucket_size=self.position_buckets, - max_position=self.max_relative_positions, - ) - shape_list_pos = shape_list(relative_pos) - if len(shape_list_pos) == 2: - relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) - elif len(shape_list_pos) == 3: - relative_pos = tf.expand_dims(relative_pos, 1) - # bsz x height x query x key - elif len(shape_list_pos) != 4: - raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}") - - att_span = self.pos_ebd_size - rel_embeddings = tf.expand_dims( - rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :], 0 - ) - if self.share_att_key: - pos_query_layer = tf.tile( - self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads), - [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], - ) - pos_key_layer = tf.tile( - self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads), - [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], - ) - else: - if "c2p" in self.pos_att_type: - pos_key_layer = tf.tile( - self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads), - [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], - ) # .split(self.all_head_size, dim=-1) - if "p2c" in self.pos_att_type: - pos_query_layer = tf.tile( - self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads), - [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1], - ) # .split(self.all_head_size, dim=-1) - - score = 0 - # content->position - if "c2p" in self.pos_att_type: - scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype)) - c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 2, 1])) - c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1) - c2p_att = take_along_axis( - c2p_att, - tf.broadcast_to( - tf.squeeze(c2p_pos, 0), - [shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]], - ), - ) - score += c2p_att / scale - - # position->content - if "p2c" in self.pos_att_type: - scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype)) - if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]: - r_pos = build_relative_position( - shape_list(key_layer)[-2], - shape_list(key_layer)[-2], - bucket_size=self.position_buckets, - max_position=self.max_relative_positions, - ) - r_pos = tf.expand_dims(r_pos, 0) - else: - r_pos = relative_pos - - p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1) - - p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1])) - p2c_att = tf.transpose( - take_along_axis( - p2c_att, - tf.broadcast_to( - tf.squeeze(p2c_pos, 0), - [shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]], - ), - ), - [0, 2, 1], - ) - score += p2c_att / scale - - return score - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query_proj", None) is not None: - with tf.name_scope(self.query_proj.name): - self.query_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "key_proj", None) is not None: - with tf.name_scope(self.key_proj.name): - self.key_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "value_proj", None) is not None: - with tf.name_scope(self.value_proj.name): - self.value_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "pos_dropout", None) is not None: - with tf.name_scope(self.pos_dropout.name): - self.pos_dropout.build(None) - if getattr(self, "pos_key_proj", None) is not None: - with tf.name_scope(self.pos_key_proj.name): - self.pos_key_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "pos_query_proj", None) is not None: - with tf.name_scope(self.pos_query_proj.name): - self.pos_query_proj.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2 -class TFDebertaV2Embeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.position_biased_input = getattr(config, "position_biased_input", True) - self.initializer_range = config.initializer_range - if self.embedding_size != config.hidden_size: - self.embed_proj = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="embed_proj", - use_bias=False, - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout") - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - if self.config.type_vocab_size > 0: - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - else: - self.token_type_embeddings = None - - with tf.name_scope("position_embeddings"): - if self.position_biased_input: - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - else: - self.position_embeddings = None - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "embed_proj", None) is not None: - with tf.name_scope(self.embed_proj.name): - self.embed_proj.build([None, None, self.embedding_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - mask: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - final_embeddings = inputs_embeds - if self.position_biased_input: - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - final_embeddings += position_embeds - if self.config.type_vocab_size > 0: - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings += token_type_embeds - - if self.embedding_size != self.hidden_size: - final_embeddings = self.embed_proj(final_embeddings) - - final_embeddings = self.LayerNorm(final_embeddings) - - if mask is not None: - if len(shape_list(mask)) != len(shape_list(final_embeddings)): - if len(shape_list(mask)) == 4: - mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1) - mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype) - - final_embeddings = final_embeddings * mask - - final_embeddings = self.dropout(final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPredictionHeadTransform with Deberta->DebertaV2 -class TFDebertaV2PredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - - self.dense = keras.layers.Dense( - units=self.embedding_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.embedding_size]) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLMPredictionHead with Deberta->DebertaV2 -class TFDebertaV2LMPredictionHead(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - - self.transform = TFDebertaV2PredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOnlyMLMHead with Deberta->DebertaV2 -class TFDebertaV2OnlyMLMHead(keras.layers.Layer): - def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - self.predictions = TFDebertaV2LMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaMainLayer with Deberta->DebertaV2 -class TFDebertaV2MainLayer(keras.layers.Layer): - config_class = DebertaV2Config - - def __init__(self, config: DebertaV2Config, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFDebertaV2Embeddings(config, name="embeddings") - self.encoder = TFDebertaV2Encoder(config, name="encoder") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - mask=attention_mask, - training=training, - ) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPreTrainedModel with Deberta->DebertaV2 -class TFDebertaV2PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DebertaV2Config - base_model_prefix = "deberta" - - -DEBERTA_START_DOCSTRING = r""" - The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled - Attention](https://huggingface.co/papers/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build - on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two - improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DEBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert *input_ids* indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput``] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaModel with Deberta->DebertaV2 -class TFDebertaV2Model(TFDebertaV2PreTrainedModel): - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - - -@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForMaskedLM with Deberta->DebertaV2 -class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFDebertaV2ForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - self.mlm = TFDebertaV2OnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """ - DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForSequenceClassification with Deberta->DebertaV2 -class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - self.pooler = TFDebertaV2ContextPooler(config, name="pooler") - - drop_out = getattr(config, "cls_dropout", None) - drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out - self.dropout = TFDebertaV2StableDropout(drop_out, name="cls_dropout") - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.output_dim = self.pooler.output_dim - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - pooled_output = self.pooler(sequence_output, training=training) - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.output_dim]) - - -@add_start_docstrings( - """ - DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForTokenClassification with Deberta->DebertaV2 -class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForQuestionAnswering with Deberta->DebertaV2 -class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.deberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - DEBERTA_START_DOCSTRING, -) -class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - # _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] - # _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: DebertaV2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.deberta = TFDebertaV2MainLayer(config, name="deberta") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.pooler = TFDebertaV2ContextPooler(config, name="pooler") - self.classifier = keras.layers.Dense( - units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.output_dim = self.pooler.output_dim - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None - flat_attention_mask = ( - tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None - ) - flat_token_type_ids = ( - tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None - ) - flat_position_ids = ( - tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None - ) - flat_inputs_embeds = ( - tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.deberta( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - pooled_output = self.pooler(sequence_output, training=training) - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deberta", None) is not None: - with tf.name_scope(self.deberta.name): - self.deberta.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.output_dim]) - - -__all__ = [ - "TFDebertaV2ForMaskedLM", - "TFDebertaV2ForQuestionAnswering", - "TFDebertaV2ForMultipleChoice", - "TFDebertaV2ForSequenceClassification", - "TFDebertaV2ForTokenClassification", - "TFDebertaV2Model", - "TFDebertaV2PreTrainedModel", -] diff --git a/src/transformers/models/deit/modeling_tf_deit.py b/src/transformers/models/deit/modeling_tf_deit.py deleted file mode 100644 index 3c56eee87911..000000000000 --- a/src/transformers/models/deit/modeling_tf_deit.py +++ /dev/null @@ -1,1232 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow DeiT model.""" - -from __future__ import annotations - -import collections.abc -import math -from dataclasses import dataclass - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFImageClassifierOutput, - TFMaskedImageModelingOutput, -) -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_deit import DeiTConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "DeiTConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224" -_EXPECTED_OUTPUT_SHAPE = [1, 198, 768] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -@dataclass -class TFDeiTForImageClassificationWithTeacherOutput(ModelOutput): - """ - Output type of [`DeiTForImageClassificationWithTeacher`]. - - Args: - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores as the average of the cls_logits and distillation logits. - cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the - class token). - distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the - distillation token). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - logits: tf.Tensor | None = None - cls_logits: tf.Tensor | None = None - distillation_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -class TFDeiTEmbeddings(keras.layers.Layer): - """ - Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token. - """ - - def __init__(self, config: DeiTConfig, use_mask_token: bool = False, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - self.use_mask_token = use_mask_token - self.patch_embeddings = TFDeiTPatchEmbeddings(config=config, name="patch_embeddings") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - - def build(self, input_shape=None): - self.cls_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=keras.initializers.zeros(), - trainable=True, - name="cls_token", - ) - self.distillation_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=keras.initializers.zeros(), - trainable=True, - name="distillation_token", - ) - self.mask_token = None - if self.use_mask_token: - self.mask_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=keras.initializers.zeros(), - trainable=True, - name="mask_token", - ) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = self.add_weight( - shape=(1, num_patches + 2, self.config.hidden_size), - initializer=keras.initializers.zeros(), - trainable=True, - name="position_embeddings", - ) - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: - num_patches = embeddings.shape[1] - 2 - num_positions = self.position_embeddings.shape[1] - 2 - - if num_patches == num_positions and height == width: - return self.position_embeddings - - class_pos_embed = self.position_embeddings[:, 0, :] - dist_pos_embed = self.position_embeddings[:, 1, :] - patch_pos_embed = self.position_embeddings[:, 2:, :] - dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # # we add a small number to avoid floating point error in the interpolation - # # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = tf.reshape( - patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - ) - patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic") - patch_pos_embed = tf.transpose(patch_pos_embed, perm=[0, 2, 3, 1]) - patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim)) - - return tf.concat( - [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1 - ) - - def call( - self, - pixel_values: tf.Tensor, - bool_masked_pos: tf.Tensor | None = None, - training: bool = False, - interpolate_pos_encoding: bool = False, - ) -> tf.Tensor: - _, height, width, _ = pixel_values.shape - - embeddings = self.patch_embeddings(pixel_values) - batch_size, seq_length, _ = shape_list(embeddings) - - if bool_masked_pos is not None: - mask_tokens = tf.tile(self.mask_token, [batch_size, seq_length, 1]) - # replace the masked visual tokens by mask_tokens - mask = tf.expand_dims(bool_masked_pos, axis=-1) - mask = tf.cast(mask, dtype=mask_tokens.dtype) - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0) - distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0) - embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1) - position_embedding = self.position_embeddings - if interpolate_pos_encoding: - position_embedding = self.interpolate_pos_encoding(embeddings, height, width) - - embeddings = embeddings + position_embedding - embeddings = self.dropout(embeddings, training=training) - return embeddings - - -class TFDeiTPatchEmbeddings(keras.layers.Layer): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config: DeiTConfig, **kwargs) -> None: - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = keras.layers.Conv2D( - hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" - ) - - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: - batch_size, height, width, num_channels = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - - x = self.projection(pixel_values) - batch_size, height, width, num_channels = shape_list(x) - x = tf.reshape(x, (batch_size, height * width, num_channels)) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->DeiT -class TFDeiTSelfAttention(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->DeiT -class TFDeiTSelfOutput(keras.layers.Layer): - """ - The residual connection is defined in TFDeiTLayer instead of here (as is the case with other models), due to the - layernorm applied before each block. - """ - - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->DeiT -class TFDeiTAttention(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFDeiTSelfAttention(config, name="attention") - self.dense_output = TFDeiTSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->DeiT -class TFDeiTIntermediate(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->DeiT -class TFDeiTOutput(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = hidden_states + input_tensor - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -class TFDeiTLayer(keras.layers.Layer): - """This corresponds to the Block class in the timm implementation.""" - - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFDeiTAttention(config, name="attention") - self.intermediate = TFDeiTIntermediate(config, name="intermediate") - self.deit_output = TFDeiTOutput(config, name="output") - - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - # in DeiT, layernorm is applied before self-attention - input_tensor=self.layernorm_before(inputs=hidden_states, training=training), - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - - # first residual connection - hidden_states = attention_output + hidden_states - - # in DeiT, layernorm is also applied after self-attention - layer_output = self.layernorm_after(inputs=hidden_states, training=training) - - intermediate_output = self.intermediate(hidden_states=layer_output, training=training) - - # second residual connection is done here - layer_output = self.deit_output( - hidden_states=intermediate_output, input_tensor=hidden_states, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "deit_output", None) is not None: - with tf.name_scope(self.deit_output.name): - self.deit_output.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.config.hidden_size]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->DeiT -class TFDeiTEncoder(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFDeiTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - head_mask=head_mask[i], - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFDeiTMainLayer(keras.layers.Layer): - config_class = DeiTConfig - - def __init__( - self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs - ) -> None: - super().__init__(**kwargs) - self.config = config - - self.embeddings = TFDeiTEmbeddings(config, use_mask_token=use_mask_token, name="embeddings") - self.encoder = TFDeiTEncoder(config, name="encoder") - - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - self.pooler = TFDeiTPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> TFDeiTPatchEmbeddings: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - def get_head_mask(self, head_mask): - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - return head_mask - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - interpolate_pos_encoding: bool = False, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # TF 2.0 image layers can't use NCHW format when running on CPU. - # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) - pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask) - - embedding_output = self.embeddings( - pixel_values, - bool_masked_pos=bool_masked_pos, - training=training, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - encoder_outputs = self.encoder( - embedding_output, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output, training=training) - pooled_output = self.pooler(sequence_output, training=training) if self.pooler is not None else None - - if not return_dict: - head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.hidden_size]) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTPreTrainedModel with ViT->DeiT all-casing -class TFDeiTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DeiTConfig - base_model_prefix = "deit" - main_input_name = "pixel_values" - - -DEIT_START_DOCSTRING = r""" - This model is a TensorFlow - [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular - TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior. - - Parameters: - config ([`DeiTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DEIT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`DeiTImageProcessor.__call__`] for details. - - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.", - DEIT_START_DOCSTRING, -) -class TFDeiTModel(TFDeiTPreTrainedModel): - def __init__( - self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs - ) -> None: - super().__init__(config, **kwargs) - - self.deit = TFDeiTMainLayer( - config, add_pooling_layer=add_pooling_layer, use_mask_token=use_mask_token, name="deit" - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - interpolate_pos_encoding: bool = False, - training: bool = False, - ) -> tuple | TFBaseModelOutputWithPooling: - outputs = self.deit( - pixel_values=pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deit", None) is not None: - with tf.name_scope(self.deit.name): - self.deit.build(None) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTPooler with ViT->DeiT -class TFDeiTPooler(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.pooler_output_size, - kernel_initializer=get_initializer(config.initializer_range), - activation=config.pooler_act, - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFDeitPixelShuffle(keras.layers.Layer): - """TF layer implementation of torch.nn.PixelShuffle""" - - def __init__(self, upscale_factor: int, **kwargs) -> None: - super().__init__(**kwargs) - if not isinstance(upscale_factor, int) or upscale_factor < 2: - raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}") - self.upscale_factor = upscale_factor - - def call(self, x: tf.Tensor) -> tf.Tensor: - hidden_states = x - batch_size, _, _, num_input_channels = shape_list(hidden_states) - block_size_squared = self.upscale_factor**2 - output_depth = int(num_input_channels / block_size_squared) - # When the number of output channels >= 2, PyTorch's PixelShuffle and - # TF's depth_to_space differ in their output as the order of channels selected for combining - # is a permutation of the other c.f. - # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1 - permutation = tf.constant( - [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]] - ) - hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1) - hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC") - return hidden_states - - -class TFDeitDecoder(keras.layers.Layer): - def __init__(self, config: DeiTConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.conv2d = keras.layers.Conv2D( - filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0" - ) - self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1") - self.config = config - - def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = inputs - hidden_states = self.conv2d(hidden_states) - hidden_states = self.pixel_shuffle(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv2d", None) is not None: - with tf.name_scope(self.conv2d.name): - self.conv2d.build([None, None, None, self.config.hidden_size]) - if getattr(self, "pixel_shuffle", None) is not None: - with tf.name_scope(self.pixel_shuffle.name): - self.pixel_shuffle.build(None) - - -@add_start_docstrings( - "DeiT Model with a decoder on top for masked image modeling, as proposed in" - " [SimMIM](https://huggingface.co/papers/2111.09886).", - DEIT_START_DOCSTRING, -) -class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): - def __init__(self, config: DeiTConfig) -> None: - super().__init__(config) - - self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="deit") - self.decoder = TFDeitDecoder(config, name="decoder") - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - interpolate_pos_encoding: bool = False, - training: bool = False, - ) -> tuple | TFMaskedImageModelingOutput: - r""" - bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - - Returns: - - Examples: - ```python - >>> from transformers import AutoImageProcessor, TFDeiTForMaskedImageModeling - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224") - >>> model = TFDeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224") - - >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 - >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values - >>> # create random boolean mask of shape (batch_size, num_patches) - >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool) - - >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction - >>> list(reconstructed_pixel_values.shape) - [1, 3, 224, 224] - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deit( - pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, - training=training, - ) - - sequence_output = outputs[0] - - # Reshape to (batch_size, num_channels, height, width) - sequence_output = sequence_output[:, 1:-1] - batch_size, sequence_length, num_channels = shape_list(sequence_output) - height = width = int(sequence_length**0.5) - sequence_output = tf.reshape(sequence_output, (batch_size, height, width, num_channels)) - - # Reconstruct pixel values - reconstructed_pixel_values = self.decoder(sequence_output, training=training) - # TF 2.0 image layers can't use NCHW format when running on CPU, so intermediate layers use NHWC, - # including the decoder. We transpose to compute the loss against the pixel values - # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) - reconstructed_pixel_values = tf.transpose(reconstructed_pixel_values, (0, 3, 1, 2)) - - masked_im_loss = None - if bool_masked_pos is not None: - size = self.config.image_size // self.config.patch_size - bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size)) - mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1) - mask = tf.repeat(mask, self.config.patch_size, 2) - mask = tf.expand_dims(mask, 1) - mask = tf.cast(mask, tf.float32) - - reconstruction_loss = keras.losses.mean_absolute_error( - # Swap axes as metric calculation reduces over the final dimension - tf.transpose(pixel_values, (1, 2, 3, 0)), - tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)), - ) - reconstruction_loss = tf.expand_dims(reconstruction_loss, 0) - total_loss = tf.reduce_sum(reconstruction_loss * mask) - num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels - masked_im_loss = total_loss / num_masked_pixels - masked_im_loss = tf.reshape(masked_im_loss, (1,)) - - if not return_dict: - output = (reconstructed_pixel_values,) + outputs[1:] - return ((masked_im_loss,) + output) if masked_im_loss is not None else output - - return TFMaskedImageModelingOutput( - loss=masked_im_loss, - reconstruction=reconstructed_pixel_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deit", None) is not None: - with tf.name_scope(self.deit.name): - self.deit.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - """ - DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - """, - DEIT_START_DOCSTRING, -) -class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: DeiTConfig): - super().__init__(config) - - self.num_labels = config.num_labels - self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit") - - # Classifier head - self.classifier = ( - keras.layers.Dense(config.num_labels, name="classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="classifier") - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFImageClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - interpolate_pos_encoding: bool = False, - training: bool = False, - ) -> tf.Tensor | TFImageClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFDeiTForImageClassification - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> keras.utils.set_random_seed(3) # doctest: +IGNORE_RESULT - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> # note: we are loading a TFDeiTForImageClassificationWithTeacher from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224") - >>> model = TFDeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] - >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) - Predicted class: little blue heron, Egretta caerulea - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deit( - pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, - training=training, - ) - - sequence_output = outputs[0] - - logits = self.classifier(sequence_output[:, 0, :]) - # we don't use the distillation token - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deit", None) is not None: - with tf.name_scope(self.deit.name): - self.deit.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of - the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. - - .. warning:: - - This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet - supported. - """, - DEIT_START_DOCSTRING, -) -class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel): - def __init__(self, config: DeiTConfig) -> None: - super().__init__(config) - - self.num_labels = config.num_labels - self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit") - - # Classifier heads - self.cls_classifier = ( - keras.layers.Dense(config.num_labels, name="cls_classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="cls_classifier") - ) - self.distillation_classifier = ( - keras.layers.Dense(config.num_labels, name="distillation_classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="distillation_classifier") - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFDeiTForImageClassificationWithTeacherOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - interpolate_pos_encoding: bool = False, - training: bool = False, - ) -> tuple | TFDeiTForImageClassificationWithTeacherOutput: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deit( - pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, - training=training, - ) - - sequence_output = outputs[0] - - cls_logits = self.cls_classifier(sequence_output[:, 0, :]) - distillation_logits = self.distillation_classifier(sequence_output[:, 1, :]) - - # during inference, return the average of both classifier predictions - logits = (cls_logits + distillation_logits) / 2 - - if not return_dict: - output = (logits, cls_logits, distillation_logits) + outputs[1:] - return output - - return TFDeiTForImageClassificationWithTeacherOutput( - logits=logits, - cls_logits=cls_logits, - distillation_logits=distillation_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "deit", None) is not None: - with tf.name_scope(self.deit.name): - self.deit.build(None) - if getattr(self, "cls_classifier", None) is not None: - with tf.name_scope(self.cls_classifier.name): - self.cls_classifier.build([None, None, self.config.hidden_size]) - if getattr(self, "distillation_classifier", None) is not None: - with tf.name_scope(self.distillation_classifier.name): - self.distillation_classifier.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFDeiTForImageClassification", - "TFDeiTForImageClassificationWithTeacher", - "TFDeiTForMaskedImageModeling", - "TFDeiTModel", - "TFDeiTPreTrainedModel", -] diff --git a/src/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py b/src/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py deleted file mode 100644 index 643097e79c3e..000000000000 --- a/src/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py +++ /dev/null @@ -1,1198 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow EfficientFormer model.""" - -import itertools -from dataclasses import dataclass -from typing import Optional, Union - -import tensorflow as tf - -from ....activations_tf import ACT2FN -from ....modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFImageClassifierOutput, -) -from ....modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ....tf_utils import shape_list, stable_softmax -from ....utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_efficientformer import EfficientFormerConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "EfficientFormerConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300" -_EXPECTED_OUTPUT_SHAPE = [1, 49, 448] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300" -_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281" - - -class TFEfficientFormerPatchEmbeddings(keras.layers.Layer): - """ - This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels, - height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride] - """ - - def __init__( - self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs - ) -> None: - super().__init__(**kwargs) - self.num_channels = num_channels - - self.padding = keras.layers.ZeroPadding2D(padding=config.downsample_pad) - self.projection = keras.layers.Conv2D( - filters=embed_dim, - kernel_size=config.downsample_patch_size, - strides=config.downsample_stride, - padding="valid", - name="projection", - ) - # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization - self.norm = ( - keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm") - if apply_norm - else tf.identity - ) - self.embed_dim = embed_dim - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: - tf.debugging.assert_shapes( - [(pixel_values, (..., None, None, self.num_channels))], - message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", - ) - embeddings = self.projection(self.padding(pixel_values)) - embeddings = self.norm(embeddings, training=training) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - if getattr(self, "norm", None) is not None: - if hasattr(self.norm, "name"): - with tf.name_scope(self.norm.name): - self.norm.build([None, None, None, self.embed_dim]) - - -class TFEfficientFormerSelfAttention(keras.layers.Layer): - def __init__( - self, - dim: int, - key_dim: int, - num_heads: int, - attention_ratio: int, - resolution: int, - config: EfficientFormerConfig, - **kwargs, - ): - super().__init__(**kwargs) - - self.num_heads = num_heads - self.key_dim = key_dim - self.attention_ratio = attention_ratio - self.scale = key_dim**-0.5 - self.total_key_dim = key_dim * num_heads - self.expanded_key_dim = int(attention_ratio * key_dim) - self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads) - hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2 - - self.qkv = keras.layers.Dense( - units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv" - ) - self.projection = keras.layers.Dense( - units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection" - ) - self.resolution = resolution - self.dim = dim - - def build(self, input_shape: tf.TensorShape) -> None: - points = list(itertools.product(range(self.resolution), range(self.resolution))) - num_points = len(points) - attention_offsets = {} - - idxs = [] - - for point_1 in points: - for point_2 in points: - offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1])) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - - self.attention_biases = self.add_weight( - shape=(self.num_heads, len(attention_offsets)), - initializer=keras.initializers.zeros(), - trainable=True, - name="attention_biases", - ) - self.attention_bias_idxs = self.add_weight( - shape=(num_points, num_points), - trainable=False, - dtype=tf.int32, - name="attention_bias_idxs", - ) - - self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points))) - - if self.built: - return - self.built = True - if getattr(self, "qkv", None) is not None: - with tf.name_scope(self.qkv.name): - self.qkv.build([None, None, self.dim]) - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, self.total_expanded_key_dim]) - - def call( - self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False - ) -> tuple[tf.Tensor]: - batch_size, sequence_length, *_ = shape_list(hidden_states) - qkv = self.qkv(inputs=hidden_states) - - query_layer, key_layer, value_layer = tf.split( - tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)), - num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim], - axis=3, - ) - - query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3]) - key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3]) - value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3]) - - attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2])) - scale = tf.cast(self.scale, dtype=attention_probs.dtype) - attention_probs = tf.multiply(attention_probs, scale) - - attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1) - attention_probs = attention_probs + attention_biases - attention_probs = stable_softmax(logits=attention_probs, axis=-1) - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - - context_layer = tf.reshape( - tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim) - ) - context_layer = self.projection(context_layer) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class TFEfficientFormerConvStem(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs): - super().__init__(**kwargs) - - self.padding = keras.layers.ZeroPadding2D(padding=1) - self.convolution1 = keras.layers.Conv2D( - filters=out_channels // 2, kernel_size=3, strides=2, padding="valid", name="convolution1" - ) - # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization - self.batchnorm_before = keras.layers.BatchNormalization( - axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" - ) - - self.convolution2 = keras.layers.Conv2D( - filters=out_channels, - kernel_size=3, - strides=2, - padding="valid", - name="convolution2", - ) - # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization - self.batchnorm_after = keras.layers.BatchNormalization( - axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" - ) - - self.activation = keras.layers.Activation(activation=keras.activations.relu, name="activation") - self.out_channels = out_channels - self.config = config - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: - features = self.batchnorm_before(self.convolution1(self.padding(pixel_values)), training=training) - features = self.activation(features) - features = self.batchnorm_after(self.convolution2(self.padding(features)), training=training) - features = self.activation(features) - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution1", None) is not None: - with tf.name_scope(self.convolution1.name): - self.convolution1.build([None, None, None, self.config.num_channels]) - if getattr(self, "batchnorm_before", None) is not None: - with tf.name_scope(self.batchnorm_before.name): - self.batchnorm_before.build([None, None, None, self.out_channels // 2]) - if getattr(self, "convolution2", None) is not None: - with tf.name_scope(self.convolution2.name): - self.convolution2.build([None, None, None, self.out_channels // 2]) - if getattr(self, "batchnorm_after", None) is not None: - with tf.name_scope(self.batchnorm_after.name): - self.batchnorm_after.build([None, None, None, self.out_channels]) - if getattr(self, "activation", None) is not None: - with tf.name_scope(self.activation.name): - self.activation.build(None) - - -class TFEfficientFormerPooling(keras.layers.Layer): - def __init__(self, pool_size: int, **kwargs): - super().__init__(**kwargs) - self.pool = keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same") - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - output = self.pool(hidden_states) - output = output - hidden_states - return output - - -class TFEfficientFormerDenseMlp(keras.layers.Layer): - def __init__( - self, - config: EfficientFormerConfig, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - **kwargs, - ): - super().__init__(**kwargs) - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.linear_in = keras.layers.Dense( - units=hidden_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_in" - ) - self.activation = ACT2FN[config.hidden_act] - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - self.linear_out = keras.layers.Dense( - units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_out" - ) - self.hidden_features = hidden_features - self.in_features = in_features - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.linear_in(inputs=hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.linear_out(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "linear_in", None) is not None: - with tf.name_scope(self.linear_in.name): - self.linear_in.build([None, None, self.in_features]) - if getattr(self, "linear_out", None) is not None: - with tf.name_scope(self.linear_out.name): - self.linear_out.build([None, None, self.hidden_features]) - - -class TFEfficientFormerConvMlp(keras.layers.Layer): - def __init__( - self, - config: EfficientFormerConfig, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - drop: float = 0.0, - **kwargs, - ): - super().__init__(**kwargs) - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.convolution1 = keras.layers.Conv2D( - filters=hidden_features, - kernel_size=1, - name="convolution1", - padding="valid", - ) - - self.activation = ACT2FN[config.hidden_act] - - self.convolution2 = keras.layers.Conv2D( - filters=out_features, - kernel_size=1, - name="convolution2", - padding="valid", - ) - - self.dropout = keras.layers.Dropout(rate=drop) - - # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization - self.batchnorm_before = keras.layers.BatchNormalization( - axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" - ) - # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization - self.batchnorm_after = keras.layers.BatchNormalization( - axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" - ) - self.hidden_features = hidden_features - self.in_features = in_features - self.out_features = out_features - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution1(hidden_state) - hidden_state = self.batchnorm_before(hidden_state, training=training) - hidden_state = self.activation(hidden_state) - hidden_state = self.dropout(hidden_state, training=training) - hidden_state = self.convolution2(hidden_state) - hidden_state = self.batchnorm_after(hidden_state, training=training) - hidden_state = self.dropout(hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution1", None) is not None: - with tf.name_scope(self.convolution1.name): - self.convolution1.build([None, None, None, self.in_features]) - if getattr(self, "convolution2", None) is not None: - with tf.name_scope(self.convolution2.name): - self.convolution2.build([None, None, None, self.hidden_features]) - if getattr(self, "batchnorm_before", None) is not None: - with tf.name_scope(self.batchnorm_before.name): - self.batchnorm_before.build([None, None, None, self.hidden_features]) - if getattr(self, "batchnorm_after", None) is not None: - with tf.name_scope(self.batchnorm_after.name): - self.batchnorm_after.build([None, None, None, self.out_features]) - - -# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer -class TFEfficientFormerDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_path: float, **kwargs): - super().__init__(**kwargs) - self.drop_path = drop_path - - def call(self, x: tf.Tensor, training=None): - if training: - keep_prob = 1 - self.drop_path - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x - - -class TFEfficientFormerFlat(keras.layers.Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def call(self, hidden_states: tf.Tensor) -> tuple[tf.Tensor]: - batch_size, _, _, in_channels = shape_list(hidden_states) - hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels]) - return hidden_states - - -class TFEfficientFormerMeta3D(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): - super().__init__(**kwargs) - - self.token_mixer = TFEfficientFormerSelfAttention( - dim=config.dim, - key_dim=config.key_dim, - num_heads=config.num_attention_heads, - attention_ratio=config.attention_ratio, - resolution=config.resolution, - name="token_mixer", - config=config, - ) - self.dim = dim - self.config = config - - self.layernorm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1") - self.layernorm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2") - mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) - self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp") - - # Using `layers.Activation` instead of `tf.identity` to better control `training' behavior. - self.drop_path = ( - TFEfficientFormerDropPath(drop_path) - if drop_path > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - self.config = config - - def build(self, input_shape=None): - self.layer_scale_1 = None - self.layer_scale_2 = None - - if self.config.use_layer_scale: - self.layer_scale_1 = self.add_weight( - shape=(self.dim,), - initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), - trainable=True, - name="layer_scale_1", - ) - self.layer_scale_2 = self.add_weight( - shape=(self.dim,), - initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), - trainable=True, - name="layer_scale_2", - ) - - if self.built: - return - self.built = True - if getattr(self, "token_mixer", None) is not None: - with tf.name_scope(self.token_mixer.name): - self.token_mixer.build(None) - if getattr(self, "layernorm1", None) is not None: - with tf.name_scope(self.layernorm1.name): - self.layernorm1.build([None, None, self.dim]) - if getattr(self, "layernorm2", None) is not None: - with tf.name_scope(self.layernorm2.name): - self.layernorm2.build([None, None, self.dim]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - - def call( - self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False - ) -> tuple[tf.Tensor]: - self_attention_outputs = self.token_mixer( - hidden_states=self.layernorm1(hidden_states, training=training), - output_attentions=output_attentions, - training=training, - ) - - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - if self.config.use_layer_scale: - layer_output = hidden_states + self.drop_path( - tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output, - training=training, - ) - layer_output = layer_output + self.drop_path( - tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) - * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), - training=training, - ) - else: - layer_output = hidden_states + self.drop_path(attention_output, training=training) - layer_output = layer_output + self.drop_path( - self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), - training=training, - ) - - outputs = (layer_output,) + outputs - - return outputs - - -class TFEfficientFormerMeta3DLayers(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, **kwargs): - super().__init__(**kwargs) - drop_paths = [ - config.drop_path_rate * (block_idx + sum(config.depths[:-1])) - for block_idx in range(config.num_meta3d_blocks) - ] - self.blocks = [ - TFEfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path, name=f"blocks.{i}") - for i, drop_path in enumerate(drop_paths) - ] - - def call( - self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False - ) -> tuple[tf.Tensor]: - all_attention_outputs = () if output_attentions else None - - for i, layer_module in enumerate(self.blocks): - if isinstance(hidden_states, tuple): - hidden_states = hidden_states[0] - - hidden_states = layer_module( - hidden_states=hidden_states, output_attentions=output_attentions, training=training - ) - if output_attentions: - all_attention_outputs = all_attention_outputs + (hidden_states[1],) - - if output_attentions: - outputs = (hidden_states[0],) + all_attention_outputs - return outputs - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "blocks", None) is not None: - for layer in self.blocks: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFEfficientFormerMeta4D(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): - super().__init__(**kwargs) - pool_size = config.pool_size if config.pool_size is not None else 3 - self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer") - self.dim = dim - mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) - self.mlp = TFEfficientFormerConvMlp( - config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp" - ) - - self.drop_path = ( - TFEfficientFormerDropPath(drop_path, name="drop_path") - if drop_path > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - self.config = config - - def build(self, input_shape=None): - self.layer_scale_1 = None - self.layer_scale_2 = None - - if self.config.use_layer_scale: - self.layer_scale_1 = self.add_weight( - shape=(self.dim), - initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), - trainable=True, - name="layer_scale_1", - ) - self.layer_scale_2 = self.add_weight( - shape=(self.dim), - initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), - trainable=True, - name="layer_scale_2", - ) - - if self.built: - return - self.built = True - if getattr(self, "token_mixer", None) is not None: - with tf.name_scope(self.token_mixer.name): - self.token_mixer.build(None) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tuple[tf.Tensor]: - outputs = self.token_mixer(hidden_states) - - if self.config.use_layer_scale: - layer_output = hidden_states + self.drop_path( - tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs, - training=training, - ) - - layer_output = layer_output + self.drop_path( - tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) - * self.mlp(hidden_state=layer_output, training=training), - training=training, - ) - - else: - layer_output = hidden_states + self.drop_path(outputs, training=training) - layer_output = layer_output + self.drop_path( - self.mlp(hidden_state=layer_output, training=training), training=training - ) - - return layer_output - - -class TFEfficientFormerMeta4DLayers(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs): - super().__init__(**kwargs) - num_layers = ( - config.depths[stage_idx] if stage_idx != -1 else config.depths[stage_idx] - config.num_meta3d_blocks - ) - drop_paths = [ - config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers) - ] - - self.blocks = [ - TFEfficientFormerMeta4D( - config=config, dim=config.hidden_sizes[stage_idx], drop_path=drop_paths[i], name=f"blocks.{i}" - ) - for i in range(len(drop_paths)) - ] - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tuple[tf.Tensor]: - for layer_module in self.blocks: - hidden_states = layer_module(hidden_states=hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "blocks", None) is not None: - for layer in self.blocks: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFEfficientFormerIntermediateStage(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, index: int, **kwargs): - super().__init__(**kwargs) - self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers") - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tuple[tf.Tensor]: - hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "meta4D_layers", None) is not None: - with tf.name_scope(self.meta4D_layers.name): - self.meta4D_layers.build(None) - - -class TFEfficientFormerLastStage(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, **kwargs): - super().__init__(**kwargs) - self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers") - self.flat = TFEfficientFormerFlat(name="flat") - self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers") - - def call( - self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False - ) -> tuple[tf.Tensor]: - hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) - hidden_states = self.flat(hidden_states=hidden_states) - hidden_states = self.meta3D_layers( - hidden_states=hidden_states, output_attentions=output_attentions, training=training - ) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "meta4D_layers", None) is not None: - with tf.name_scope(self.meta4D_layers.name): - self.meta4D_layers.build(None) - if getattr(self, "flat", None) is not None: - with tf.name_scope(self.flat.name): - self.flat.build(None) - if getattr(self, "meta3D_layers", None) is not None: - with tf.name_scope(self.meta3D_layers.name): - self.meta3D_layers.build(None) - - -class TFEfficientFormerEncoder(keras.layers.Layer): - def __init__(self, config: EfficientFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - num_intermediate_stages = len(config.depths) - 1 - downsamples = [ - config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1] - for i in range(num_intermediate_stages) - ] - - intermediate_stages = [] - layer_count = -1 - for i in range(num_intermediate_stages): - layer_count += 1 - intermediate_stages.append( - TFEfficientFormerIntermediateStage(config, i, name=f"intermediate_stages.{layer_count}") - ) - if downsamples[i]: - layer_count += 1 - intermediate_stages.append( - TFEfficientFormerPatchEmbeddings( - config, - config.hidden_sizes[i], - config.hidden_sizes[i + 1], - name=f"intermediate_stages.{layer_count}", - ) - ) - self.intermediate_stages = intermediate_stages - self.last_stage = TFEfficientFormerLastStage(config, name="last_stage") - - def call( - self, - hidden_states: tf.Tensor, - output_hidden_states: bool, - output_attentions: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - for layer_module in self.intermediate_stages: - hidden_states = layer_module(hidden_states, training=training) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_output = self.last_stage(hidden_states, output_attentions=output_attentions, training=training) - - if output_attentions: - all_self_attentions = all_self_attentions + layer_output[1:] - - if output_hidden_states: - all_hidden_states = all_hidden_states + (layer_output[0],) - - if not return_dict: - return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=layer_output[0], - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "last_stage", None) is not None: - with tf.name_scope(self.last_stage.name): - self.last_stage.build(None) - for layer in self.intermediate_stages: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFEfficientFormerMainLayer(keras.layers.Layer): - config_class = EfficientFormerConfig - - def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - - self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed") - self.encoder = TFEfficientFormerEncoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - - @unpack_inputs - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_attentions: Optional[tf.Tensor] = None, - output_hidden_states: Optional[tf.Tensor] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[TFBaseModelOutput, tuple[tf.Tensor, ...]]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # When running on CPU, keras.layers.Conv2D and keras.layers.AveragePool2D do not - # support channels first NCHW format. A number of blocks contain both. - # So change the input format from (batch_size, num_channels, height, width) to - # (batch_size, height, width, num_channels) here. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - embedding_output = self.patch_embed(pixel_values, training=training) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output, training=training) - - # Change the hidden states from (batch_size, height, width, num_channels) to - # (batch_size, num_channels, height, width). - # The hidden states are in (batch_size, height, width, num_channels) - # shape after all stages except the MB3D blocks. - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1][:-1]) + ( - encoder_outputs[1][-1], - ) - - if not return_dict: - head_outputs = (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "patch_embed", None) is not None: - with tf.name_scope(self.patch_embed.name): - self.patch_embed.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.hidden_sizes[-1]]) - - -class TFEfficientFormerPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = EfficientFormerConfig - base_model_prefix = "efficientformer" - main_input_name = "pixel_values" - - -EFFICIENTFORMER_START_DOCSTRING = r""" - This model is a TensorFlow - [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular - TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior. - - - Parameters: - config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -EFFICIENTFORMER_INPUTS_DOCSTRING = r""" - Args: - pixel_values ((`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`EfficientFormerImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.", - EFFICIENTFORMER_START_DOCSTRING, -) -class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel): - def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple, TFBaseModelOutput]: - outputs = self.efficientformer( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "efficientformer", None) is not None: - with tf.name_scope(self.efficientformer.name): - self.efficientformer.build(None) - - -@add_start_docstrings( - """ - EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for - ImageNet. - """, - EFFICIENTFORMER_START_DOCSTRING, -) -class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: EfficientFormerConfig): - super().__init__(config) - - self.num_labels = config.num_labels - self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") - - # Classifier head - self.classifier = ( - keras.layers.Dense(config.num_labels, name="classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="classifier") - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFImageClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tf.Tensor, TFImageClassifierOutput]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.efficientformer( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - - logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutput( - loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "efficientformer", None) is not None: - with tf.name_scope(self.efficientformer.name): - self.efficientformer.build(None) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_sizes[-1]]) - - -@dataclass -class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput): - """ - Args: - Output type of [`EfficientFormerForImageClassificationWithTeacher`]. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores as the average of the cls_logits and distillation logits. - cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the - class token). - distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the - distillation token). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when - `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when - `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - logits: Optional[tf.Tensor] = None - cls_logits: Optional[tf.Tensor] = None - distillation_logits: Optional[tf.Tensor] = None - hidden_states: Optional[tuple[tf.Tensor]] = None - attentions: Optional[tuple[tf.Tensor]] = None - - -@add_start_docstrings( - """ - EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden - state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. - - .. warning:: - This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet - supported. - """, - EFFICIENTFORMER_START_DOCSTRING, -) -class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel): - def __init__(self, config: EfficientFormerConfig) -> None: - super().__init__(config) - - self.num_labels = config.num_labels - self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") - - # Classifier heads - self.classifier = ( - keras.layers.Dense(config.num_labels, name="classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="classifier") - ) - self.distillation_classifier = ( - keras.layers.Dense(config.num_labels, name="distillation_classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="distillation_classifier") - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFEfficientFormerForImageClassificationWithTeacherOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if training: - raise Exception( - "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported." - ) - - outputs = self.efficientformer( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - - cls_logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) - distillation_logits = self.distillation_classifier(tf.reduce_mean(sequence_output, axis=-2)) - logits = (cls_logits + distillation_logits) / 2 - - if not return_dict: - output = (logits, cls_logits, distillation_logits) + outputs[1:] - return output - - return TFEfficientFormerForImageClassificationWithTeacherOutput( - logits=logits, - cls_logits=cls_logits, - distillation_logits=distillation_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "efficientformer", None) is not None: - with tf.name_scope(self.efficientformer.name): - self.efficientformer.build(None) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_sizes[-1]]) - if getattr(self, "distillation_classifier", None) is not None: - if hasattr(self.distillation_classifier, "name"): - with tf.name_scope(self.distillation_classifier.name): - self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]]) - - -__all__ = [ - "TFEfficientFormerForImageClassification", - "TFEfficientFormerForImageClassificationWithTeacher", - "TFEfficientFormerModel", - "TFEfficientFormerPreTrainedModel", -] diff --git a/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py b/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 76b9c9cf328c..000000000000 --- a/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,181 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model.""" - -import argparse -import json -import os -from collections import OrderedDict - -import numpy as np -import tensorflow as tf -import torch - - -def convert_tf_gptsan_to_pt(args): - parameter_file = os.path.join(args.tf_model_dir, "parameters.json") - params = json.loads(open(parameter_file).read()) - if not params: - raise ValueError( - f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file." - ) - if not args.output.endswith(".pt"): - args.output = args.output + ".pt" - new_state = OrderedDict() - with tf.device("/CPU:0"): - reader = tf.train.load_checkpoint(args.tf_model_dir) - shapes = reader.get_variable_to_shape_map() - for key_name in shapes: - vnp = reader.get_tensor(key_name).astype(np.float16) - if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"): - continue - if key_name.startswith("pasts/"): - if key_name.startswith("pasts/mlp"): - player = int(key_name[9]) - elif key_name.startswith("pasts/out"): - player = 8 - name = "model.sqout.%d.weight" % (player * 2) # enter to nn.Sequential with Tanh, so 2 at a time - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/moe"): - player = int(key_name[9:].split("/")[0]) - if key_name.endswith("/switch_gating/kernel"): - name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.endswith("/softmlp/kernel"): - name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"): - nlayer = key_name[-9:-7] - for i in range(16): - name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer) - state = ( - vnp[i].transpose([1, 0]).copy() - ) # In Mesh-Tensorflow, it is one array, so it is divided - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/mlp"): - player = int(key_name[9:].split("/")[0]) - if key_name.endswith("/p1/kernel"): - name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.endswith("/p1/bias"): - name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif key_name.endswith("/p2/kernel"): - name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.endswith("/p2/bias"): - name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/ln"): - player = int(key_name[8:].split("/")[0]) - if key_name.endswith("/b"): - name = "model.blocks.%d.feed_forward.norm.bias" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif key_name.endswith("/g"): - name = "model.blocks.%d.feed_forward.norm.weight" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/att"): - player = int(key_name[9:].split("/")[0]) - if key_name.endswith("/qkv/kernel"): - state = vnp.copy() # Compute same dimension as Mesh-tensorflow using einsum - state_q = state[:, 0, :, :] - state_k = state[:, 1, :, :] - state_v = state[:, 2, :, :] - state_q = ( - state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]]) - .transpose([1, 0]) - .copy() - ) # Mesh-Tensorflow is a diagonal matrix - state_k = ( - state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]]) - .transpose([1, 0]) - .copy() - ) # Mesh-Tensorflow is a diagonal matrix - state_v = ( - state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]]) - .transpose([1, 0]) - .copy() - ) # Mesh-Tensorflow is a diagonal matrix - name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player - new_state[name] = torch.tensor(state_q) - name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player - new_state[name] = torch.tensor(state_k) - name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player - new_state[name] = torch.tensor(state_v) - elif key_name.endswith("/o/kernel"): - name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player - state = ( - vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy() - ) # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/an"): - player = int(key_name[8:].split("/")[0]) - if key_name.endswith("/b"): - name = "model.blocks.%d.self_attn.norm.bias" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif key_name.endswith("/g"): - name = "model.blocks.%d.self_attn.norm.weight" % player - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - elif ( - key_name.startswith("model/wte") - or key_name.startswith("model/wpe") - or key_name.startswith("model/ete") - ): - nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[ - key_name[-3:] - ] - name = "model.%s.weight" % nlayer - state = vnp.copy() # same in embedded - new_state[name] = torch.tensor(state) - if key_name.startswith("model/wte"): - name = "lm_head.weight" - state = vnp.copy() # same in embedded - new_state[name] = torch.tensor(state) - elif key_name.startswith("model/wob"): - name = "final_logits_bias" - state = vnp.copy() # same in embedded - state = state.reshape((1, -1)) - new_state[name] = torch.tensor(state) - elif key_name == "model/dense/kernel": - name = "model.last_project.weight" - state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix - new_state[name] = torch.tensor(state) - elif key_name == "model/dense_1/bias": - name = "model.last_project.bias" - state = vnp.copy() # same because it is one dimensional - new_state[name] = torch.tensor(state) - torch.save(new_state, args.output) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model") - parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model") - args = parser.parse_args() - convert_tf_gptsan_to_pt(args) diff --git a/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 2c7b687c4d98..000000000000 --- a/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,121 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert Transformer XL checkpoint and datasets.""" - -import argparse -import os -import pickle -import sys - -import torch - -from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl -from transformers.models.deprecated.transfo_xl import tokenization_transfo_xl as data_utils -from transformers.models.deprecated.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -logging.set_verbosity_info() - -# We do this to be able to load python 2 datasets pickles -# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 -data_utils.Vocab = data_utils.TransfoXLTokenizer -data_utils.Corpus = data_utils.TransfoXLCorpus -sys.modules["data_utils"] = data_utils -sys.modules["vocabulary"] = data_utils - - -def convert_transfo_xl_checkpoint_to_pytorch( - tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file -): - if transfo_xl_dataset_file: - # Convert a pre-processed corpus (see original TensorFlow repo) - with open(transfo_xl_dataset_file, "rb") as fp: - corpus = pickle.load(fp, encoding="latin1") - # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) - pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"] - print(f"Save vocabulary to {pytorch_vocab_dump_path}") - corpus_vocab_dict = corpus.vocab.__dict__ - torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) - - corpus_dict_no_vocab = corpus.__dict__ - corpus_dict_no_vocab.pop("vocab", None) - pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME - print(f"Save dataset to {pytorch_dataset_dump_path}") - torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) - - if tf_checkpoint_path: - # Convert a pre-trained TensorFlow model - config_path = os.path.abspath(transfo_xl_config_file) - tf_path = os.path.abspath(tf_checkpoint_path) - - print(f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}.") - # Initialise PyTorch model - if transfo_xl_config_file == "": - config = TransfoXLConfig() - else: - config = TransfoXLConfig.from_json_file(transfo_xl_config_file) - print(f"Building PyTorch model from configuration: {config}") - model = TransfoXLLMHeadModel(config) - - model = load_tf_weights_in_transfo_xl(model, config, tf_path) - # Save pytorch-model - pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) - pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) - print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--pytorch_dump_folder_path", - default=None, - type=str, - required=True, - help="Path to the folder to store the PyTorch model or dataset/vocab.", - ) - parser.add_argument( - "--tf_checkpoint_path", - default="", - type=str, - help="An optional path to a TensorFlow checkpoint path to be converted.", - ) - parser.add_argument( - "--transfo_xl_config_file", - default="", - type=str, - help=( - "An optional config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--transfo_xl_dataset_file", - default="", - type=str, - help="An optional dataset file to be converted in a vocabulary.\n" - "Given the files are in the pickle format, please be wary of passing it files you trust.", - ) - args = parser.parse_args() - convert_transfo_xl_checkpoint_to_pytorch( - args.tf_checkpoint_path, - args.transfo_xl_config_file, - args.pytorch_dump_folder_path, - args.transfo_xl_dataset_file, - ) diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py deleted file mode 100644 index 3c7830d63344..000000000000 --- a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py +++ /dev/null @@ -1,1128 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TF 2.0 Transformer XL model. -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ....modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ....tf_utils import shape_list, stable_softmax -from ....utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_transfo_xl import TransfoXLConfig -from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "transfo-xl/transfo-xl-wt103" -_CONFIG_FOR_DOC = "TransfoXLConfig" - - -class TFPositionalEmbedding(keras.layers.Layer): - def __init__(self, demb, **kwargs): - super().__init__(**kwargs) - - self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb)) - - def call(self, pos_seq, bsz=None): - self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype) - sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq) - pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) - - if bsz is not None: - return tf.tile(pos_emb[:, None, :], [1, bsz, 1]) - else: - return pos_emb[:, None, :] - - -class TFPositionwiseFF(keras.layers.Layer): - def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, init_std=0.02, **kwargs): - super().__init__(**kwargs) - - self.d_model = d_model - self.d_inner = d_inner - self.dropout = dropout - - self.layer_1 = keras.layers.Dense( - d_inner, kernel_initializer=get_initializer(init_std), activation=tf.nn.relu, name="CoreNet_._0" - ) - self.drop_1 = keras.layers.Dropout(dropout) - self.layer_2 = keras.layers.Dense(d_model, kernel_initializer=get_initializer(init_std), name="CoreNet_._3") - self.drop_2 = keras.layers.Dropout(dropout) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm") - - self.pre_lnorm = pre_lnorm - - def call(self, inp, training=False): - if self.pre_lnorm: - # layer normalization + positionwise feed-forward - core_out = self.layer_norm(inp) - core_out = self.layer_1(core_out) - core_out = self.drop_1(core_out, training=training) - core_out = self.layer_2(core_out) - core_out = self.drop_2(core_out, training=training) - - # residual connection - output = core_out + inp - else: - # positionwise feed-forward - core_out = self.layer_1(inp) - core_out = self.drop_1(core_out, training=training) - core_out = self.layer_2(core_out) - core_out = self.drop_2(core_out, training=training) - - # residual connection + layer normalization - output = self.layer_norm(inp + core_out) - - return output - - -class TFRelPartialLearnableMultiHeadAttn(keras.layers.Layer): - def __init__( - self, - n_head, - d_model, - d_head, - dropout, - dropatt=0.0, - pre_lnorm=False, - r_r_bias=None, - r_w_bias=None, - layer_norm_epsilon=1e-5, - init_std=0.02, - output_attentions=False, - **kwargs, - ): - super().__init__(**kwargs) - - self.n_head = n_head - self.d_model = d_model - self.d_head = d_head - self.dropout = dropout - self.output_attentions = output_attentions - - self.qkv_net = keras.layers.Dense( - 3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net" - ) - - self.drop = keras.layers.Dropout(dropout) - self.dropatt = keras.layers.Dropout(dropatt) - self.o_net = keras.layers.Dense( - d_model, kernel_initializer=get_initializer(init_std), use_bias=False, name="o_net" - ) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm") - - self.scale = 1 / (d_head**0.5) - - self.pre_lnorm = pre_lnorm - - if r_r_bias is not None and r_w_bias is not None: # Biases are shared - self.r_r_bias = r_r_bias - self.r_w_bias = r_w_bias - else: - self.r_r_bias = None - self.r_w_bias = None - - self.r_net = keras.layers.Dense( - self.n_head * self.d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="r_net" - ) - - def build(self, input_shape): - if self.r_r_bias is None or self.r_w_bias is None: # Biases are not shared - self.r_r_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias" - ) - self.r_w_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias" - ) - super().build(input_shape) - - def _rel_shift(self, x): - x_size = shape_list(x) - - x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) - x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]]) - x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) - x = tf.reshape(x, x_size) - - return x - - def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False): - qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1] - - if mems is not None: - mems = tf.cast(mems, dtype=w.dtype) - cat = tf.concat([mems, w], 0) - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(cat)) - else: - w_heads = self.qkv_net(cat) - r_head_k = self.r_net(r) - - w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) - w_head_q = w_head_q[-qlen:] - else: - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(w)) - else: - w_heads = self.qkv_net(w) - r_head_k = self.r_net(r) - - w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) - - klen = shape_list(w_head_k)[0] - - w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head - w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head - w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head - - r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head - - # compute attention score - rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head - AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head - - rr_head_q = w_head_q + self.r_r_bias - BD = tf.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k) # qlen x klen x bsz x n_head - BD = self._rel_shift(BD) - - # [qlen x klen x bsz x n_head] - attn_score = AC + BD - attn_score = attn_score * self.scale - - # compute attention probability - if attn_mask is not None: - attn_mask_t = attn_mask[:, :, None, None] - attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype) - attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t - - # [qlen x klen x bsz x n_head] - attn_prob = stable_softmax(attn_score, axis=1) - attn_prob = self.dropatt(attn_prob, training=training) - - # Mask heads if we want to - if head_mask is not None: - attn_prob = attn_prob * head_mask - - # compute attention vector - attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v) - - # [qlen x bsz x n_head x d_head] - attn_vec_sizes = shape_list(attn_vec) - attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head)) - - # linear projection - attn_out = self.o_net(attn_vec) - attn_out = self.drop(attn_out, training=training) - - if self.pre_lnorm: - # residual connection - outputs = [w + attn_out] - else: - # residual connection + layer normalization - outputs = [self.layer_norm(w + attn_out)] - - if output_attentions: - outputs.append(attn_prob) - - return outputs - - -class TFRelPartialLearnableDecoderLayer(keras.layers.Layer): - def __init__( - self, - n_head, - d_model, - d_head, - d_inner, - dropout, - dropatt=0.0, - pre_lnorm=False, - r_w_bias=None, - r_r_bias=None, - layer_norm_epsilon=1e-5, - init_std=0.02, - output_attentions=False, - **kwargs, - ): - super().__init__(**kwargs) - - self.dec_attn = TFRelPartialLearnableMultiHeadAttn( - n_head, - d_model, - d_head, - dropout, - dropatt=dropatt, - pre_lnorm=pre_lnorm, - r_w_bias=r_w_bias, - r_r_bias=r_r_bias, - init_std=init_std, - layer_norm_epsilon=layer_norm_epsilon, - output_attentions=output_attentions, - name="dec_attn", - ) - self.pos_ff = TFPositionwiseFF( - d_model, - d_inner, - dropout, - pre_lnorm=pre_lnorm, - init_std=init_std, - layer_norm_epsilon=layer_norm_epsilon, - name="pos_ff", - ) - - def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False): - attn_outputs = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training) - ff_output = self.pos_ff(attn_outputs[0], training=training) - - outputs = [ff_output] + attn_outputs[1:] - - return outputs - - -class TFTransfoEmbeddings(keras.layers.Layer): - def __init__(self, vocab_size, emb_size, init_std, **kwargs): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.emb_size = emb_size - self.init_std = init_std - - def build(self, input_shape): - self.weight = self.add_weight( - shape=(self.vocab_size, self.emb_size), - initializer=get_initializer(self.init_std), - name="embeddings", - ) - - super().build(input_shape) - - def call(self, inputs): - return tf.gather(self.weight, inputs) - - -class TFAdaptiveEmbedding(keras.layers.Layer): - def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs): - super().__init__(**kwargs) - - self.n_token = n_token - self.d_embed = d_embed - self.init_std = init_std - - self.cutoffs = cutoffs + [n_token] - self.div_val = div_val - self.d_proj = d_proj - - self.emb_scale = d_proj**0.5 - - self.cutoff_ends = [0] + self.cutoffs - - self.emb_layers = [] - self.emb_projs = [] - - if div_val == 1: - raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - d_emb_i = d_embed // (div_val**i) - self.emb_layers.append( - TFTransfoEmbeddings( - r_idx - l_idx, - d_emb_i, - init_std, - name=f"emb_layers_._{i}", - ) - ) - - def build(self, input_shape): - for i in range(len(self.cutoffs)): - d_emb_i = self.d_embed // (self.div_val**i) - self.emb_projs.append( - self.add_weight( - shape=(d_emb_i, self.d_proj), - initializer=get_initializer(self.init_std), - trainable=True, - name=f"emb_projs_._{i}", - ) - ) - - super().build(input_shape) - - def call(self, inp): - if self.div_val == 1: - raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint - else: - inp_flat = tf.reshape(inp, (-1,)) - emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj]) - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - - mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) - - inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx - emb_i = self.emb_layers[i](inp_i) - emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i]) - - mask_idx = tf.where(mask_i) - scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat)) - emb_flat = tf.cast(emb_flat, dtype=scatter.dtype) - emb_flat += scatter - - embed_shape = shape_list(inp) + [self.d_proj] - embed = tf.reshape(emb_flat, embed_shape) - - embed *= self.emb_scale - - return embed - - -@keras_serializable -class TFTransfoXLMainLayer(keras.layers.Layer): - config_class = TransfoXLConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.return_dict = config.use_return_dict - - self.n_token = config.vocab_size - - self.d_embed = config.d_embed - self.d_model = config.d_model - self.n_head = config.n_head - self.d_head = config.d_head - self.untie_r = config.untie_r - - self.word_emb = TFAdaptiveEmbedding( - config.vocab_size, - config.d_embed, - config.d_model, - config.cutoffs, - div_val=config.div_val, - init_std=config.init_std, - name="word_emb", - ) - - self.drop = keras.layers.Dropout(config.dropout) - - self.n_layer = config.n_layer - self.mem_len = config.mem_len - self.attn_type = config.attn_type - - self.layers = [] - if config.attn_type == 0: # the default attention - for i in range(config.n_layer): - self.layers.append( - TFRelPartialLearnableDecoderLayer( - config.n_head, - config.d_model, - config.d_head, - config.d_inner, - config.dropout, - dropatt=config.dropatt, - pre_lnorm=config.pre_lnorm, - r_w_bias=None if self.untie_r else self.r_w_bias, - r_r_bias=None if self.untie_r else self.r_r_bias, - layer_norm_epsilon=config.layer_norm_epsilon, - init_std=config.init_std, - output_attentions=self.output_attentions, - name=f"layers_._{i}", - ) - ) - else: # learnable embeddings and absolute embeddings - raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint - - self.same_length = config.same_length - self.clamp_len = config.clamp_len - - if self.attn_type == 0: # default attention - self.pos_emb = TFPositionalEmbedding(self.d_model, name="pos_emb") - else: # learnable embeddings and absolute embeddings - raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint - - def build(self, input_shape): - if not self.untie_r: - self.r_w_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias" - ) - self.r_r_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias" - ) - super().build(input_shape) - - def get_input_embeddings(self): - return self.word_emb - - def set_input_embeddings(self, value): - raise NotImplementedError - - def backward_compatible(self): - self.sample_softmax = -1 - - def reset_memory_length(self, mem_len): - self.mem_len = mem_len - - def _prune_heads(self, heads): - raise NotImplementedError - - def init_mems(self, bsz): - if self.mem_len > 0: - mems = [] - for i in range(self.n_layer): - empty = tf.zeros([self.mem_len, bsz, self.d_model]) - mems.append(empty) - - return mems - else: - return None - - def _update_mems(self, hids, mems, mlen, qlen): - # does not deal with None - if mems is None: - return None - - # mems is not None - assert len(hids) == len(mems), "len(hids) != len(mems)" - - # There are `mlen + qlen` steps that can be cached into mems - new_mems = [] - end_idx = mlen + tf.math.maximum(0, qlen) - beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len)) - for i in range(len(hids)): - mems[i] = tf.cast(mems[i], dtype=hids[i].dtype) - cat = tf.concat([mems[i], hids[i]], axis=0) - tf.stop_gradient(cat) - new_mems.append(cat[beg_idx:end_idx]) - - return new_mems - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - mems: list[tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ): - # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library - # so we transpose here from shape [bsz, len] to shape [len, bsz] - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_ids = tf.transpose(input_ids, perm=(1, 0)) - qlen, bsz = shape_list(input_ids) - elif inputs_embeds is not None: - inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) - qlen, bsz = shape_list(inputs_embeds)[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if mems is None: - mems = self.init_mems(bsz) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) - # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.n_layer - - if inputs_embeds is not None: - word_emb = inputs_embeds - else: - word_emb = self.word_emb(input_ids) - - mlen = shape_list(mems[0])[0] if mems is not None else 0 - klen = mlen + qlen - - # Compute decoder attention mask - all_ones = tf.ones([qlen, klen], dtype=tf.int32) - upper_mask = 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, mlen) - if self.same_length: - mask_len = klen - self.mem_len - mask_shift_len = qlen - tf.nn.relu(mask_len) # Lazy clamping of negatives to zero - - # Use an indicator variable instead of a conditional to keep the compiler happy - lower_mask = tf.linalg.band_part(all_ones, -1, 0) - ( - tf.linalg.band_part(all_ones, mask_shift_len - 1, 0) * tf.cast(mask_shift_len != 0, tf.int32) - ) - dec_attn_mask = upper_mask + lower_mask - else: - dec_attn_mask = upper_mask - - hids = [] - attentions = [] if output_attentions else None - if self.attn_type == 0: # default - pos_seq = tf.range(klen - 1, -1, -1.0) - if self.clamp_len > 0: - pos_seq = tf.minimum(pos_seq, self.clamp_len) - pos_emb = self.pos_emb(pos_seq) - - core_out = self.drop(word_emb, training=training) - pos_emb = self.drop(pos_emb, training=training) - - for i, layer in enumerate(self.layers): - hids.append(core_out) - mems_i = None if mems is None else mems[i] - layer_outputs = layer( - core_out, - pos_emb, - dec_attn_mask, - mems_i, - head_mask[i], - output_attentions, - training=training, - ) - core_out = layer_outputs[0] - if output_attentions: - attentions.append(layer_outputs[1]) - else: # learnable embeddings and absolute embeddings - raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint - - core_out = self.drop(core_out, training=training) - - new_mems = self._update_mems(hids, mems, mlen, qlen) - - # We transpose back here to shape [bsz, len, hidden_dim] - core_out = tf.transpose(core_out, perm=(1, 0, 2)) - - if output_hidden_states: - # Transpose to library standard shape [bsz, len, hidden_dim] and add last layer - hids = tuple(tf.transpose(t, perm=(1, 0, 2)) for t in hids) - hids = hids + (core_out,) - else: - hids = None - if output_attentions: - # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] - attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) - - if not return_dict: - return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None) - - return TFTransfoXLModelOutput( - last_hidden_state=core_out, - mems=new_mems, - hidden_states=hids, - attentions=attentions, - ) - - -class TFTransfoXLPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = TransfoXLConfig - base_model_prefix = "transformer" - - -@dataclass -class TFTransfoXLModelOutput(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` - input) to speed up sequential decoding. The token ids which have their past given to this model should not - be passed as input ids as they have already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - mems: list[tf.Tensor] = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFTransfoXLLMHeadModelOutput(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - losses (`tf.Tensor` of shape *(batch_size, sequence_length-1)*, *optional*, returned when `labels` is provided): - Language modeling losses (not reduced). - prediction_scores (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` - input) to speed up sequential decoding. The token ids which have their past given to this model should not - be passed as input ids as they have already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - prediction_scores: tf.Tensor | None = None - mems: list[tf.Tensor] = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFTransfoXLSequenceClassifierOutputWithPast(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` - input) to speed up sequential decoding. The token ids which have their past given to this model should not - be passed as input ids as they have already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mems: list[tf.Tensor] = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -TRANSFO_XL_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TRANSFO_XL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see - `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems - given to this model should not be passed as `input_ids` as they have already been computed. - head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", - TRANSFO_XL_START_DOCSTRING, -) -class TFTransfoXLModel(TFTransfoXLPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFTransfoXLMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTransfoXLModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - mems: list[tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFTransfoXLModelOutput | tuple[tf.Tensor]: - outputs = self.transformer( - input_ids=input_ids, - mems=mems, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - -@add_start_docstrings( - """ - The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive - input embeddings) - """, - TRANSFO_XL_START_DOCSTRING, -) -class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.transformer = TFTransfoXLMainLayer(config, name="transformer") - self.sample_softmax = config.sample_softmax - assert self.sample_softmax <= 0, ( - "Sampling from the softmax is not implemented yet. Please look at issue: #3310:" - " https://github.com/huggingface/transformers/issues/3310" - ) - - self.crit = TFAdaptiveSoftmaxMask( - config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit" - ) - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError() - - def get_output_embeddings(self): - """Double-check if you are using adaptive softmax.""" - if len(self.crit.out_layers) > 0: - return self.crit.out_layers[-1] - return None - - def reset_memory_length(self, mem_len): - self.transformer.reset_memory_length(mem_len) - - def init_mems(self, bsz): - return self.transformer.init_mems(bsz) - - @unpack_inputs - @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTransfoXLLMHeadModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - mems: list[tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFTransfoXLLMHeadModelOutput | tuple[tf.Tensor]: - if input_ids is not None: - bsz, tgt_len = shape_list(input_ids)[:2] - else: - bsz, tgt_len = shape_list(inputs_embeds)[:2] - - transformer_outputs = self.transformer( - input_ids, - mems, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, - training=training, - ) - - last_hidden = transformer_outputs[0] - pred_hid = last_hidden[:, -tgt_len:] - - softmax_output = self.crit(pred_hid, labels, training=training) - prediction_scores = softmax_output if labels is None else () - - if not return_dict: - return (prediction_scores,) + transformer_outputs[1:] - - return TFTransfoXLLMHeadModelOutput( - prediction_scores=prediction_scores, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs): - inputs = {} - - # if past is defined in model kwargs then use it for faster decoding - if past_key_values: - input_ids = tf.expand_dims(input_ids[:, -1], axis=-1) - else: - input_ids = input_ids - - return inputs - - # Adapted from the torch tie_weights function - def tf_to_pt_weight_rename(self, tf_weight): - if self.config.tie_word_embeddings and "crit.out_layers" in tf_weight: - return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers") - elif self.config.tie_projs and "crit.out_projs" in tf_weight: - for i, tie_proj in enumerate(self.config.tie_projs): - if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: - # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] - return tf_weight, tf_weight.replace(f"crit.out_projs.{i}", "transformer.word_emb.emb_projs.0") - elif tie_proj and self.config.div_val != 1: - # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] - return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs") - else: - return (tf_weight,) - - -@add_start_docstrings( - """ - The Transfo XL Model transformer with a sequence classification head on top (linear layer). - - [`TFTransfoXLForSequenceClassification`] uses the last token in order to do the classification, as other causal - models (e.g. GPT-1,GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - TRANSFO_XL_START_DOCSTRING, -) -class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.score = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.init_range), - name="score", - use_bias=False, - ) - self.transformer = TFTransfoXLMainLayer(config, name="transformer") - - def get_output_embeddings(self): - # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too. - logger.warning( - "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed " - "in transformers v4.32." - ) - return self.transformer.word_emb - - @unpack_inputs - @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTransfoXLSequenceClassifierOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - mems: list[tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFTransfoXLSequenceClassifierOutputWithPast: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - mems=mems, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - in_logits = None - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = ( - tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) - - 1 - ) - sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) - in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) - else: - sequence_lengths = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - if labels is not None: - if input_ids is not None: - batch_size, sequence_length = shape_list(input_ids)[:2] - else: - batch_size, sequence_length = shape_list(inputs_embeds)[:2] - assert self.config.pad_token_id is not None or batch_size == 1, ( - "Cannot handle batch sizes > 1 if no padding token is defined." - ) - - if not tf.is_tensor(sequence_lengths): - in_logits = logits[0:batch_size, sequence_lengths] - - loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])) - - pooled_logits = in_logits if in_logits is not None else logits - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTransfoXLSequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -__all__ = [ - "TFAdaptiveEmbedding", - "TFTransfoXLForSequenceClassification", - "TFTransfoXLLMHeadModel", - "TFTransfoXLMainLayer", - "TFTransfoXLModel", - "TFTransfoXLPreTrainedModel", -] diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py deleted file mode 100644 index 48205e06fb20..000000000000 --- a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py +++ /dev/null @@ -1,178 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -A TF 2.0 Adaptive Softmax for Transformer XL model. -""" - -import tensorflow as tf - -from ....modeling_tf_utils import keras -from ....tf_utils import shape_list - - -class TFAdaptiveSoftmaxMask(keras.layers.Layer): - def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.d_embed = d_embed - self.d_proj = d_proj - - self.cutoffs = cutoffs + [vocab_size] - self.cutoff_ends = [0] + self.cutoffs - self.div_val = div_val - - self.shortlist_size = self.cutoffs[0] - self.n_clusters = len(self.cutoffs) - 1 - self.head_size = self.shortlist_size + self.n_clusters - self.keep_order = keep_order - - self.out_layers = [] - self.out_projs = [] - - def build(self, input_shape): - if self.n_clusters > 0: - self.cluster_weight = self.add_weight( - shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight" - ) - self.cluster_bias = self.add_weight( - shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias" - ) - - if self.div_val == 1: - for i in range(len(self.cutoffs)): - if self.d_proj != self.d_embed: - weight = self.add_weight( - shape=(self.d_embed, self.d_proj), - initializer="zeros", - trainable=True, - name=f"out_projs_._{i}", - ) - self.out_projs.append(weight) - else: - self.out_projs.append(None) - weight = self.add_weight( - shape=(self.vocab_size, self.d_embed), - initializer="zeros", - trainable=True, - name=f"out_layers_._{i}_._weight", - ) - bias = self.add_weight( - shape=(self.vocab_size,), - initializer="zeros", - trainable=True, - name=f"out_layers_._{i}_._bias", - ) - self.out_layers.append((weight, bias)) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - d_emb_i = self.d_embed // (self.div_val**i) - - weight = self.add_weight( - shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name=f"out_projs_._{i}" - ) - self.out_projs.append(weight) - weight = self.add_weight( - shape=(r_idx - l_idx, d_emb_i), - initializer="zeros", - trainable=True, - name=f"out_layers_._{i}_._weight", - ) - bias = self.add_weight( - shape=(r_idx - l_idx,), - initializer="zeros", - trainable=True, - name=f"out_layers_._{i}_._bias", - ) - self.out_layers.append((weight, bias)) - super().build(input_shape) - - @staticmethod - def _logit(x, W, b, proj=None): - y = x - if proj is not None: - y = tf.einsum("ibd,ed->ibe", y, proj) - return tf.einsum("ibd,nd->ibn", y, W) + b - - @staticmethod - def _gather_logprob(logprob, target): - lp_size = shape_list(logprob) - r = tf.range(lp_size[0], dtype=target.dtype) - idx = tf.stack([r, target], 1) - return tf.gather_nd(logprob, idx) - - def call(self, hidden, target, return_mean=True, training=False): - head_logprob = 0 - if self.n_clusters == 0: - output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0]) - if target is not None: - loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output) - out = tf.nn.log_softmax(output, axis=-1) - else: - hidden_sizes = shape_list(hidden) - out = [] - loss = tf.zeros(hidden_sizes[:2]) - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - if target is not None: - mask = (target >= l_idx) & (target < r_idx) - mask_idx = tf.where(mask) - cur_target = tf.boolean_mask(target, mask) - l_idx - - if self.div_val == 1: - cur_W = self.out_layers[0][0][l_idx:r_idx] - cur_b = self.out_layers[0][1][l_idx:r_idx] - else: - cur_W = self.out_layers[i][0] - cur_b = self.out_layers[i][1] - - if i == 0: - cur_W = tf.concat([cur_W, self.cluster_weight], 0) - cur_b = tf.concat([cur_b, self.cluster_bias], 0) - - head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0]) - head_logprob = tf.nn.log_softmax(head_logit) - out.append(head_logprob[..., : self.cutoffs[0]]) - if target is not None: - cur_head_logprob = tf.boolean_mask(head_logprob, mask) - cur_logprob = self._gather_logprob(cur_head_logprob, cur_target) - else: - tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i]) - tail_logprob = tf.nn.log_softmax(tail_logit) - cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster - logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob - out.append(logprob_i) - if target is not None: - cur_head_logprob = tf.boolean_mask(head_logprob, mask) - cur_tail_logprob = tf.boolean_mask(tail_logprob, mask) - cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target) - cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1] - if target is not None: - loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss)) - out = tf.concat(out, axis=-1) - - if target is not None: - if return_mean: - loss = tf.reduce_mean(loss) - # Add the training-time loss value to the layer using `self.add_loss()`. - self.add_loss(loss) - - # Log the loss as a metric (we could log arbitrary metrics, - # including different metrics for training and inference. - self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "") - - return out diff --git a/src/transformers/models/dinov2/modeling_flax_dinov2.py b/src/transformers/models/dinov2/modeling_flax_dinov2.py deleted file mode 100644 index b9ea2eaa3ebc..000000000000 --- a/src/transformers/models/dinov2/modeling_flax_dinov2.py +++ /dev/null @@ -1,801 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax DINOv2 model.""" - -import collections.abc -import math -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward -from .configuration_dinov2 import Dinov2Config - - -DINOV2_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -DINOV2_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Dinov2ImageProcessor.__call__`] - for details. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxDinov2PatchEmbeddings(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - image_size = self.config.image_size - patch_size = self.config.patch_size - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - - self.num_patches = num_patches - self.num_channels = self.config.num_channels - self.projection = nn.Conv( - self.config.hidden_size, - kernel_size=patch_size, - strides=patch_size, - padding="VALID", - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - ) - - # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings.__call__ - def __call__(self, pixel_values): - num_channels = pixel_values.shape[-1] - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - embeddings = self.projection(pixel_values) - batch_size, _, _, channels = embeddings.shape - return jnp.reshape(embeddings, (batch_size, -1, channels)) - - -class FlaxDinov2Embeddings(nn.Module): - """Construct the CLS token, position and patch embeddings.""" - - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.cls_token = self.param( - "cls_token", - jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), - (1, 1, self.config.hidden_size), - ) - if self.config.use_mask_token: - self.mask_token = self.param( - "mask_token", - jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), - (1, self.config.hidden_size), - ) - self.patch_embeddings = FlaxDinov2PatchEmbeddings(self.config, dtype=self.dtype) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = self.param( - "position_embeddings", - jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), - (1, num_patches + 1, self.config.hidden_size), - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def interpolate_pos_encoding(self, config, hidden_states, height, width, position_embeddings): - num_patches = hidden_states.shape[1] - 1 - num_positions = position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: - return position_embeddings - class_pos_embed = position_embeddings[:, 0] - patch_pos_embed = position_embeddings[:, 1:] - dim = hidden_states.shape[-1] - - h = height // config.patch_size - w = width // config.patch_size - height, width = h + 0.1, w + 0.1 - - patch_pos_embed = patch_pos_embed.reshape( - (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - ) - patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 3, 1, 2)) - target_dtype = patch_pos_embed.dtype - new_height_ratio = jnp.float32(height / math.sqrt(num_positions)) - new_width_ratio = jnp.float32(width / math.sqrt(num_positions)) - - scale = jnp.array([new_height_ratio, new_width_ratio], dtype=jnp.float32) - translation = jnp.array([0.0, 0.0], dtype=jnp.float32) - - patch_pos_embed = jax.image.scale_and_translate( - patch_pos_embed.astype(jnp.float32), - shape=(patch_pos_embed.shape[0], patch_pos_embed.shape[1], h, w), - spatial_dims=(2, 3), - scale=scale, - translation=translation, - method="bicubic", - antialias=False, - ) - patch_pos_embed = patch_pos_embed.astype(target_dtype) - patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim)) - patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1)) - class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1)) - - return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1) - - def __call__(self, pixel_values, deterministic=True): - batch_size = pixel_values.shape[0] - target_dtype = self.patch_embeddings.projection.dtype - height, width = pixel_values.shape[1], pixel_values.shape[2] - - embeddings = self.patch_embeddings(pixel_values.astype(target_dtype)) - - cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) - embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) - - embeddings = embeddings + self.interpolate_pos_encoding( - self.config, embeddings, height, width, self.position_embeddings - ) - - embeddings = self.dropout(embeddings, deterministic=deterministic) - return embeddings - - -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->Dinov2 -class FlaxDinov2SelfAttention(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" - " {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - - def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): - head_dim = self.config.hidden_size // self.config.num_attention_heads - - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->Dinov2 -class FlaxDinov2SelfOutput(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->Dinov2 -class FlaxDinov2Attention(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.attention = FlaxDinov2SelfAttention(self.config, dtype=self.dtype) - self.output = FlaxDinov2SelfOutput(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): - attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -def ones_with_scale(key, shape, scale, dtype=jnp.float32): - return jnp.ones(shape, dtype) * scale - - -class FlaxDinov2LayerScale(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.lambda1 = self.config.layerscale_value * self.param( - "lambda1", - jax.nn.initializers.ones, - (self.config.hidden_size,), - ) - self.lambda1 = self.lambda1 * self.config.layerscale_value - - def __call__(self, hidden_states): - return self.lambda1 * hidden_states - - -# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitDropPath with Beit -> Dinov2 -class FlaxDinov2DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - rate: float - - @nn.module.compact - def __call__(self, inputs, deterministic: Optional[bool] = True): - if self.rate == 0.0: - return inputs - keep_prob = 1.0 - self.rate - if deterministic: - return inputs - else: - shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - rng = self.make_rng("droppath") - random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype) - binary_tensor = jnp.floor(random_tensor) - output = inputs / keep_prob * binary_tensor - return output - - -class FlaxDinov2MLP(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.fc1 = nn.Dense( - self.config.hidden_size * self.config.mlp_ratio, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.fc2 = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - if isinstance(self.config.hidden_act, str): - self.act = ACT2FN[self.config.hidden_act] - else: - self.act = self.config.hidden_act - - def __call__(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class FlaxDinov2SwiGLUFFN(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - hidden_features = int(self.config.hidden_size * self.config.mlp_ratio) - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - - self.weights_in = nn.Dense( - 2 * hidden_features, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.weights_out = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - hidden_states = self.weights_in(hidden_states) - x1, x2 = jnp.split(hidden_states, 2, axis=-1) - hidden = nn.silu(x1) * x2 - return self.weights_out(hidden) - - -class FlaxDinov2Layer(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.attention = FlaxDinov2Attention(self.config, dtype=self.dtype) - self.layer_scale1 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) - self.drop_path = FlaxDinov2DropPath(self.config.drop_path_rate) - self.norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - if self.config.use_swiglu_ffn: - self.mlp = FlaxDinov2SwiGLUFFN(self.config, dtype=self.dtype) - else: - self.mlp = FlaxDinov2MLP(self.config, dtype=self.dtype) - - self.layer_scale2 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): - self_attention_outputs = self.attention( - self.norm1(hidden_states), # in Dinov2, layernorm is applied before self-attention - deterministic=deterministic, - output_attentions=output_attentions, - ) - - attention_output = self_attention_outputs[0] - - attention_output = self.layer_scale1(attention_output) - - outputs = self_attention_outputs[1:] - - # first residual connection - hidden_states = self.drop_path(attention_output) + hidden_states - - # in Dinov2, layernorm is also applied after self-attention - layer_output = self.norm2(hidden_states) - layer_output = self.mlp(layer_output) - layer_output = self.layer_scale2(layer_output) - - # second residual connection - layer_output = self.drop_path(layer_output) + hidden_states - - outputs = (layer_output,) + outputs - - return outputs - - -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViT->Dinov2 -class FlaxDinov2LayerCollection(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxDinov2Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states,) - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViT->Dinov2 -class FlaxDinov2Encoder(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxDinov2LayerCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = Dinov2Config - base_model_prefix = "dinov2" - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: Dinov2Config, - input_shape=None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, config.num_channels) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - pixel_values = jnp.zeros(input_shape, dtype=self.dtype) - - params_rng, dropout_rng = jax.random.split(rng) - dropout_rng, droppath_rng = jax.random.split(dropout_rng) - rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} - - random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - pixel_values, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - dropout_rng, droppath_rng = jax.random.split(dropout_rng) - rngs["dropout"] = dropout_rng - rngs["droppath"] = droppath_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxDinov2Module(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embeddings = FlaxDinov2Embeddings(self.config, dtype=self.dtype) - self.encoder = FlaxDinov2Encoder(self.config, dtype=self.dtype) - self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__( - self, - pixel_values, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embeddings(pixel_values, deterministic=deterministic) - - encoder_outputs = self.encoder( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = sequence_output[:, 0, :] - - if not return_dict: - head_outputs = (sequence_output, pooled_output) - return head_outputs + encoder_outputs[1:] - - return FlaxBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings( - "The bare Dinov2 Model transformer outputting raw hidden-states without any specific head on top.", - DINOV2_START_DOCSTRING, -) -class FlaxDinov2Model(FlaxDinov2PreTrainedModel): - module_class = FlaxDinov2Module - - -FLAX_VISION_MODEL_DOCSTRING = """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, FlaxDinov2Model - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") - >>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -overwrite_call_docstring(FlaxDinov2Model, FLAX_VISION_MODEL_DOCSTRING) -append_replace_return_docstrings( - FlaxDinov2Model, output_type=FlaxBaseModelOutputWithPooling, config_class=Dinov2Config -) - - -class FlaxDinov2ForImageClassificationModule(nn.Module): - config: Dinov2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dinov2 = FlaxDinov2Module(config=self.config, dtype=self.dtype) - self.classifier = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - ) - - def __call__( - self, - pixel_values=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.dinov2( - pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - cls_token = hidden_states[:, 0] - patch_tokens = hidden_states[:, 1:] - linear_input = jnp.concatenate([cls_token, patch_tokens.mean(axis=1)], axis=-1) - - logits = self.classifier(linear_input) - - if not return_dict: - output = (logits,) + outputs[2:] - return output - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - """, - DINOV2_START_DOCSTRING, -) -class FlaxDinov2ForImageClassification(FlaxDinov2PreTrainedModel): - module_class = FlaxDinov2ForImageClassificationModule - - -FLAX_VISION_CLASSIFICATION_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxDinov2ForImageClassification - >>> from PIL import Image - >>> import jax - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer") - >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True) - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) - >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) - ``` -""" - -overwrite_call_docstring(FlaxDinov2ForImageClassification, FLAX_VISION_CLASSIFICATION_DOCSTRING) -append_replace_return_docstrings( - FlaxDinov2ForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=Dinov2Config -) - - -__all__ = ["FlaxDinov2ForImageClassification", "FlaxDinov2Model", "FlaxDinov2PreTrainedModel"] diff --git a/src/transformers/models/distilbert/modeling_flax_distilbert.py b/src/transformers/models/distilbert/modeling_flax_distilbert.py deleted file mode 100644 index fba3dfd9d332..000000000000 --- a/src/transformers/models/distilbert/modeling_flax_distilbert.py +++ /dev/null @@ -1,906 +0,0 @@ -# coding=utf-8 -# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_distilbert import DistilBertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "distilbert-base-uncased" -_CONFIG_FOR_DOC = "DistilBertConfig" - - -FLAX_DISTILBERT_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DISTILBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def get_angles(pos, i, d_model): - angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) - return pos * angle_rates - - -def positional_encoding(position, d_model): - # create the sinusoidal pattern for the positional encoding - angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) - - # apply sin to even indices in the array; 2i - angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) - - # apply cos to odd indices in the array; 2i+1 - angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) - - pos_encoding = angle_rads[np.newaxis, ...] - - return jnp.array(pos_encoding) - - -class FlaxEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.dim, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - if not self.config.sinusoidal_pos_embds: - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.dim, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - else: - self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim) - self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.dropout) - - def __call__(self, input_ids, deterministic: bool = True): - # Embed - batch_size, seq_length = input_ids.shape - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - if not self.config.sinusoidal_pos_embds: - position_ids = jnp.arange(seq_length).astype("i4") - position_ids = jnp.broadcast_to(position_ids, shape=(batch_size, seq_length)) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - else: - position_embeds = self.pos_encoding[:, :seq_length, :] - # explicitly cast the positions here, since self.embed_positions are not registered as parameters - position_embeds = position_embeds.astype(inputs_embeds.dtype) - - # Sum all embeddings - hidden_states = inputs_embeds + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxMultiHeadSelfAttention(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.n_heads = self.config.n_heads - self.dim = self.config.dim - self.dropout = nn.Dropout(rate=self.config.attention_dropout) - - if not (self.dim % self.n_heads == 0): - raise ValueError(f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}") - - self.q_lin = nn.Dense( - self.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.k_lin = nn.Dense( - self.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.v_lin = nn.Dense( - self.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.out_lin = nn.Dense( - self.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - query, - key, - value, - mask, - deterministic: bool = True, - output_attentions: bool = False, - ): - bs, q_len, dim = query.shape - k_len = key.shape[1] - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - # assert key.size() == value.size() - - dim_per_head = self.dim // self.n_heads - - mask_reshp = (bs, 1, 1, k_len) - - def shape(x): - """separate heads""" - return x.reshape(bs, -1, self.n_heads, dim_per_head).transpose(0, 2, 1, 3) - - def unshape(x): - """group heads""" - return x.transpose(0, 2, 1, 3).reshape(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(query)) # (bs, n_heads, q_len, dim_per_head) - k = shape(self.k_lin(key)) # (bs, n_heads, k_len, dim_per_head) - v = shape(self.v_lin(value)) # (bs, n_heads, k_len, dim_per_head) - - q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_len, dim_per_head) - scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) # (bs, n_heads, q_len, k_len) - mask = jnp.reshape(mask, mask_reshp) - - mask = mask.astype(scores.dtype) - scores = scores - 1e30 * (1.0 - mask) - - weights = nn.softmax(scores, axis=-1) # (bs, n_heads, q_len, k_len) - weights = self.dropout(weights, deterministic=deterministic) - - context = jnp.matmul(weights, v) # (bs, n_heads, q_len, dim_per_head) - context = unshape(context) # (bs, q_len, dim) - context = self.out_lin(context) # (bs, q_len, dim) - - if output_attentions: - return (context, weights) - else: - return (context,) - - -class FlaxFFN(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout = nn.Dropout(rate=self.config.dropout) - self.chunk_size_feed_forward = self.config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.lin1 = nn.Dense( - self.config.hidden_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.lin2 = nn.Dense( - self.config.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - self.activation = ACT2FN[self.config.activation] - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.lin1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.lin2(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxTransformerBlock(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - assert self.config.dim % self.config.n_heads == 0, ( - f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}" - ) - - self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype) - self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) - - self.ffn = FlaxFFN(self.config, dtype=self.dtype) - self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attn_mask, - output_attentions: bool = False, - deterministic: bool = True, - ): - # Self-Attention - sa_output = self.attention( - query=hidden_states, - key=hidden_states, - value=hidden_states, - mask=attn_mask, - output_attentions=output_attentions, - deterministic=deterministic, - ) - if output_attentions: - sa_output, sa_weights = sa_output - else: - assert type(sa_output) is tuple - sa_output = sa_output[0] - sa_output = self.sa_layer_norm(sa_output + hidden_states) - - # Feed Forward Network - ffn_output = self.ffn(sa_output, deterministic=deterministic) - ffn_output = self.output_layer_norm(ffn_output + sa_output) - output = (ffn_output,) - if output_attentions: - output = (sa_weights,) + output - return output - - -class FlaxTransformer(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxTransformerBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.n_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - output_attentions: bool = False, - output_hidden_states: bool = False, - deterministic: bool = True, - return_dict: bool = False, - ): - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for layer_module in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - attn_mask=attention_mask, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = layer_outputs[-1] - - if output_attentions: - assert len(layer_outputs) == 2 - attentions = layer_outputs[0] - all_attentions = all_attentions + (attentions,) - else: - assert len(layer_outputs) == 1 - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_attentions, all_hidden_states] if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxTransformerEncoder(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxTransformer(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - output_attentions: bool = False, - output_hidden_states: bool = False, - deterministic: bool = True, - return_dict: bool = False, - ): - return self.layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - deterministic=deterministic, - return_dict=return_dict, - ) - - -class FlaxDistilBertLMDecoder(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, inputs, kernel): - inputs = jnp.asarray(inputs, self.dtype) - kernel = jnp.asarray(kernel, self.dtype) - y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ()))) - bias = jnp.asarray(self.bias, self.dtype) - y = y + bias - return y - - -class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DistilBertConfig - base_model_prefix = "distilbert" - module_class: nn.Module = None - - def __init__( - self, - config: DistilBertConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - head_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxDistilBertModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embeddings = FlaxEmbeddings(self.config, dtype=self.dtype) - self.transformer = FlaxTransformerEncoder(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - input_embeds = self.embeddings(input_ids, deterministic=deterministic) - return self.transformer( - hidden_states=input_embeds, - attention_mask=attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -@add_start_docstrings( - "The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top.", - FLAX_DISTILBERT_START_DOCSTRING, -) -class FlaxDistilBertModel(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertModule - - -append_call_sample_docstring(FlaxDistilBertModel, _CHECKPOINT_FOR_DOC, None, _CONFIG_FOR_DOC) - - -class FlaxDistilBertForMaskedLMModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype) - self.vocab_transform = nn.Dense( - self.config.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) - if self.config.tie_word_embeddings: - self.vocab_projector = FlaxDistilBertLMDecoder( - self.config, - dtype=self.dtype, - ) - else: - self.vocab_projector = nn.Dense( - self.config.vocab_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - dlbrt_output = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - deterministic=deterministic, - return_dict=return_dict, - ) - hidden_states = dlbrt_output[0] - prediction_logits = self.vocab_transform(hidden_states) - prediction_logits = ACT2FN[self.config.activation](prediction_logits) - prediction_logits = self.vocab_layer_norm(prediction_logits) - - if self.config.tie_word_embeddings: - shared_embedding = self.distilbert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - prediction_logits = self.vocab_projector(prediction_logits, shared_embedding.T) - else: - prediction_logits = self.vocab_projector(prediction_logits) - - if not return_dict: - output = (prediction_logits,) + dlbrt_output[1:] - return output - - return FlaxMaskedLMOutput( - logits=prediction_logits, - hidden_states=dlbrt_output.hidden_states, - attentions=dlbrt_output.attentions, - ) - - -@add_start_docstrings("""DistilBert Model with a `language modeling` head on top.""", FLAX_DISTILBERT_START_DOCSTRING) -class FlaxDistilBertForMaskedLM(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertForMaskedLMModule - - -append_call_sample_docstring(FlaxDistilBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) - - -class FlaxDistilBertForSequenceClassificationModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) - self.pre_classifier = nn.Dense( - self.config.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout) - self.classifier = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - ) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # Model - distilbert_output = self.distilbert( - input_ids, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_state = distilbert_output[0] # (bs, seq_len, dim) - pooled_output = hidden_state[:, 0] # (bs, dim) - pooled_output = self.pre_classifier(pooled_output) # (bs, dim) - pooled_output = ACT2FN["relu"](pooled_output) - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) # (bs, dim) - - if not return_dict: - return (logits,) + distilbert_output[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - -@add_start_docstrings( - """ - DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - FLAX_DISTILBERT_START_DOCSTRING, -) -class FlaxDistilBertForSequenceClassification(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxDistilBertForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxDistilBertForMultipleChoiceModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) - self.pre_classifier = nn.Dense( - self.config.dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout) - self.classifier = nn.Dense( - 1, - dtype=self.dtype, - ) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - - # Model - outputs = self.distilbert( - input_ids, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_state = outputs[0] - pooled_output = hidden_state[:, 0] - pooled_output = self.pre_classifier(pooled_output) - pooled_output = ACT2FN["relu"](pooled_output) - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - FLAX_DISTILBERT_START_DOCSTRING, -) -class FlaxDistilBertForMultipleChoice(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxDistilBertForMultipleChoice, DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxDistilBertForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxDistilBertForTokenClassificationModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # Model - outputs = self.distilbert( - input_ids, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - FLAX_DISTILBERT_START_DOCSTRING, -) -class FlaxDistilBertForTokenClassification(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertForTokenClassificationModule - - -append_call_sample_docstring( - FlaxDistilBertForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxDistilBertForQuestionAnsweringModule(nn.Module): - config: DistilBertConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - assert self.config.num_labels == 2 - self.dropout = nn.Dropout(rate=self.config.qa_dropout) - - def __call__( - self, - input_ids, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Model - distilbert_output = self.distilbert( - input_ids, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = distilbert_output[0] - - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + distilbert_output[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - -@add_start_docstrings( - """ - DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - FLAX_DISTILBERT_START_DOCSTRING, -) -class FlaxDistilBertForQuestionAnswering(FlaxDistilBertPreTrainedModel): - module_class = FlaxDistilBertForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxDistilBertForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxDistilBertForMaskedLM", - "FlaxDistilBertForMultipleChoice", - "FlaxDistilBertForQuestionAnswering", - "FlaxDistilBertForSequenceClassification", - "FlaxDistilBertForTokenClassification", - "FlaxDistilBertModel", - "FlaxDistilBertPreTrainedModel", -] diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py deleted file mode 100644 index a2efa1105c1c..000000000000 --- a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ /dev/null @@ -1,1146 +0,0 @@ -# coding=utf-8 -# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TF 2.0 DistilBERT model -""" - -from __future__ import annotations - -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_distilbert import DistilBertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "distilbert-base-uncased" -_CONFIG_FOR_DOC = "DistilBertConfig" - - -class TFEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dim = config.dim - self.initializer_range = config.initializer_range - self.max_position_embeddings = config.max_position_embeddings - self.LayerNorm = keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.dropout) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.dim], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.dim], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.dim]) - - def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - final_embeddings = inputs_embeds + position_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFMultiHeadSelfAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.n_heads = config.n_heads - self.dim = config.dim - self.dropout = keras.layers.Dropout(config.attention_dropout) - self.output_attentions = config.output_attentions - - assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}" - - self.q_lin = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin" - ) - self.k_lin = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="k_lin" - ) - self.v_lin = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="v_lin" - ) - self.out_lin = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="out_lin" - ) - - self.pruned_heads = set() - self.config = config - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, query, key, value, mask, head_mask, output_attentions, training=False): - """ - Parameters: - query: tf.Tensor(bs, seq_length, dim) - key: tf.Tensor(bs, seq_length, dim) - value: tf.Tensor(bs, seq_length, dim) - mask: tf.Tensor(bs, seq_length) - - Returns: - weights: tf.Tensor(bs, n_heads, seq_length, seq_length) Attention weights context: tf.Tensor(bs, - seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` - """ - bs, q_length, dim = shape_list(query) - k_length = shape_list(key)[1] - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - # assert key.size() == value.size() - dim_per_head = int(self.dim / self.n_heads) - dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) - mask_reshape = [bs, 1, 1, k_length] - - def shape(x): - """separate heads""" - return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3)) - - def unshape(x): - """group heads""" - return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) - - q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) - k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) - v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) - q = tf.cast(q, dtype=tf.float32) - q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) - k = tf.cast(k, dtype=q.dtype) - scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length) - mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) - # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length) - - mask = tf.cast(mask, dtype=scores.dtype) - scores = scores - 1e30 * (1.0 - mask) - weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) - weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) - - # Mask heads if we want to - if head_mask is not None: - weights = weights * head_mask - - context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, q_length, dim) - context = self.out_lin(context) # (bs, q_length, dim) - - if output_attentions: - return (context, weights) - else: - return (context,) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_lin", None) is not None: - with tf.name_scope(self.q_lin.name): - self.q_lin.build([None, None, self.config.dim]) - if getattr(self, "k_lin", None) is not None: - with tf.name_scope(self.k_lin.name): - self.k_lin.build([None, None, self.config.dim]) - if getattr(self, "v_lin", None) is not None: - with tf.name_scope(self.v_lin.name): - self.v_lin.build([None, None, self.config.dim]) - if getattr(self, "out_lin", None) is not None: - with tf.name_scope(self.out_lin.name): - self.out_lin.build([None, None, self.config.dim]) - - -class TFFFN(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dropout = keras.layers.Dropout(config.dropout) - self.lin1 = keras.layers.Dense( - config.hidden_dim, kernel_initializer=get_initializer(config.initializer_range), name="lin1" - ) - self.lin2 = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2" - ) - self.activation = get_tf_activation(config.activation) - self.config = config - - def call(self, input, training=False): - x = self.lin1(input) - x = self.activation(x) - x = self.lin2(x) - x = self.dropout(x, training=training) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lin1", None) is not None: - with tf.name_scope(self.lin1.name): - self.lin1.build([None, None, self.config.dim]) - if getattr(self, "lin2", None) is not None: - with tf.name_scope(self.lin2.name): - self.lin2.build([None, None, self.config.hidden_dim]) - - -class TFTransformerBlock(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.n_heads = config.n_heads - self.dim = config.dim - self.hidden_dim = config.hidden_dim - self.dropout = keras.layers.Dropout(config.dropout) - self.activation = config.activation - self.output_attentions = config.output_attentions - - assert config.dim % config.n_heads == 0, ( - f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}" - ) - - self.attention = TFMultiHeadSelfAttention(config, name="attention") - self.sa_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm") - - self.ffn = TFFFN(config, name="ffn") - self.output_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm") - self.config = config - - def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None - """ - Parameters: - x: tf.Tensor(bs, seq_length, dim) - attn_mask: tf.Tensor(bs, seq_length) - - Outputs: sa_weights: tf.Tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: - tf.Tensor(bs, seq_length, dim) The output of the transformer block contextualization. - """ - # Self-Attention - sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training) - if output_attentions: - sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) - else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples - # assert type(sa_output) == tuple - sa_output = sa_output[0] - sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) - - # Feed Forward Network - ffn_output = self.ffn(sa_output, training=training) # (bs, seq_length, dim) - ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) - - output = (ffn_output,) - if output_attentions: - output = (sa_weights,) + output - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "sa_layer_norm", None) is not None: - with tf.name_scope(self.sa_layer_norm.name): - self.sa_layer_norm.build([None, None, self.config.dim]) - if getattr(self, "ffn", None) is not None: - with tf.name_scope(self.ffn.name): - self.ffn.build(None) - if getattr(self, "output_layer_norm", None) is not None: - with tf.name_scope(self.output_layer_norm.name): - self.output_layer_norm.build([None, None, self.config.dim]) - - -class TFTransformer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.n_layers = config.n_layers - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - - self.layer = [TFTransformerBlock(config, name=f"layer_._{i}") for i in range(config.n_layers)] - - def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False): - # docstyle-ignore - """ - Parameters: - x: tf.Tensor(bs, seq_length, dim) Input sequence embedded. - attn_mask: tf.Tensor(bs, seq_length) Attention mask on the sequence. - - Returns: - hidden_state: tf.Tensor(bs, seq_length, dim) - Sequence of hidden states in the last (top) layer - all_hidden_states: tuple[tf.Tensor(bs, seq_length, dim)] - Tuple of length n_layers with the hidden states from each layer. - Optional: only if output_hidden_states=True - all_attentions: tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)] - Tuple of length n_layers with the attention weights from each layer - Optional: only if output_attentions=True - """ - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_state = x - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_state,) - - layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training) - hidden_state = layer_outputs[-1] - - if output_attentions: - assert len(layer_outputs) == 2 - attentions = layer_outputs[0] - all_attentions = all_attentions + (attentions,) - else: - assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1" - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_state,) - - if not return_dict: - return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFDistilBertMainLayer(keras.layers.Layer): - config_class = DistilBertConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.num_hidden_layers = config.num_hidden_layers - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - - self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings - self.transformer = TFTransformer(config, name="transformer") # Encoder - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = value.shape[0] - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.ones(input_shape) # (bs, seq_length) - - attention_mask = tf.cast(attention_mask, dtype=tf.float32) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - - embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim) - tfmr_output = self.transformer( - embedding_output, - attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=training, - ) - - return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # -class TFDistilBertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DistilBertConfig - base_model_prefix = "distilbert" - - -DISTILBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DISTILBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.", - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertModel(TFDistilBertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") # Embeddings - - @unpack_inputs - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - - -class TFDistilBertLMHead(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.dim = config.dim - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.dim]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings( - """DistilBert Model with a `masked language modeling` head on top.""", - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.vocab_transform = keras.layers.Dense( - config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform" - ) - self.act = get_tf_activation(config.activation) - self.vocab_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm") - self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector") - - def get_lm_head(self): - return self.vocab_projector - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.vocab_projector.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - distilbert_output = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = distilbert_output[0] # (bs, seq_length, dim) - prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) - prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim) - prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) - prediction_logits = self.vocab_projector(prediction_logits) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_logits) - - if not return_dict: - output = (prediction_logits,) + distilbert_output[1:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - if getattr(self, "vocab_transform", None) is not None: - with tf.name_scope(self.vocab_transform.name): - self.vocab_transform.build([None, None, self.config.dim]) - if getattr(self, "vocab_layer_norm", None) is not None: - with tf.name_scope(self.vocab_layer_norm.name): - self.vocab_layer_norm.build([None, None, self.config.dim]) - if getattr(self, "vocab_projector", None) is not None: - with tf.name_scope(self.vocab_projector.name): - self.vocab_projector.build(None) - - -@add_start_docstrings( - """ - DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.pre_classifier = keras.layers.Dense( - config.dim, - kernel_initializer=get_initializer(config.initializer_range), - activation="relu", - name="pre_classifier", - ) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.dropout = keras.layers.Dropout(config.seq_classif_dropout) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - distilbert_output = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_state = distilbert_output[0] # (bs, seq_len, dim) - pooled_output = hidden_state[:, 0] # (bs, dim) - pooled_output = self.pre_classifier(pooled_output) # (bs, dim) - pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) - logits = self.classifier(pooled_output) # (bs, dim) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + distilbert_output[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - if getattr(self, "pre_classifier", None) is not None: - with tf.name_scope(self.pre_classifier.name): - self.pre_classifier.build([None, None, self.config.dim]) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.dim]) - - -@add_start_docstrings( - """ - DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.dropout = keras.layers.Dropout(config.dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.dropout = keras.layers.Dropout(config.seq_classif_dropout) - self.pre_classifier = keras.layers.Dense( - config.dim, - kernel_initializer=get_initializer(config.initializer_range), - activation="relu", - name="pre_classifier", - ) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - distilbert_output = self.distilbert( - flat_input_ids, - flat_attention_mask, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_state = distilbert_output[0] # (bs, seq_len, dim) - pooled_output = hidden_state[:, 0] # (bs, dim) - pooled_output = self.pre_classifier(pooled_output) # (bs, dim) - pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + distilbert_output[1:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - if getattr(self, "pre_classifier", None) is not None: - with tf.name_scope(self.pre_classifier.name): - self.pre_classifier.build([None, None, self.config.dim]) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.dim]) - - -@add_start_docstrings( - """ - DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - DISTILBERT_START_DOCSTRING, -) -class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2" - self.dropout = keras.layers.Dropout(config.qa_dropout) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - distilbert_output = self.distilbert( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = distilbert_output[0] # (bs, max_query_len, dim) - hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim) - logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + distilbert_output[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=distilbert_output.hidden_states, - attentions=distilbert_output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "distilbert", None) is not None: - with tf.name_scope(self.distilbert.name): - self.distilbert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.dim]) - - -__all__ = [ - "TFDistilBertForMaskedLM", - "TFDistilBertForMultipleChoice", - "TFDistilBertForQuestionAnswering", - "TFDistilBertForSequenceClassification", - "TFDistilBertForTokenClassification", - "TFDistilBertMainLayer", - "TFDistilBertModel", - "TFDistilBertPreTrainedModel", -] diff --git a/src/transformers/models/dpr/modeling_tf_dpr.py b/src/transformers/models/dpr/modeling_tf_dpr.py deleted file mode 100644 index aef83e6c55fb..000000000000 --- a/src/transformers/models/dpr/modeling_tf_dpr.py +++ /dev/null @@ -1,799 +0,0 @@ -# coding=utf-8 -# Copyright 2018 DPR Authors, The Hugging Face Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""TensorFlow DPR model for Open Domain Question Answering.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import tensorflow as tf - -from ...modeling_tf_outputs import TFBaseModelOutputWithPooling -from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, get_initializer, keras, shape_list, unpack_inputs -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from ..bert.modeling_tf_bert import TFBertMainLayer -from .configuration_dpr import DPRConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "DPRConfig" - - -########## -# Outputs -########## - - -@dataclass -class TFDPRContextEncoderOutput(ModelOutput): - r""" - Class for outputs of [`TFDPRContextEncoder`]. - - Args: - pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`): - The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer - hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. - This output is to be used to embed contexts for nearest neighbors queries with questions embeddings. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFDPRQuestionEncoderOutput(ModelOutput): - """ - Class for outputs of [`TFDPRQuestionEncoder`]. - - Args: - pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`): - The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer - hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. - This output is to be used to embed questions for nearest neighbors queries with context embeddings. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFDPRReaderOutput(ModelOutput): - """ - Class for outputs of [`TFDPRReaderEncoder`]. - - Args: - start_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`): - Logits of the start index of the span for each passage. - end_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`): - Logits of the end index of the span for each passage. - relevance_logits (`tf.Tensor` of shape `(n_passages, )`): - Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the - question, compared to all the other passages. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - start_logits: tf.Tensor | None = None - end_logits: tf.Tensor | None = None - relevance_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -class TFDPREncoderLayer(keras.layers.Layer): - base_model_prefix = "bert_model" - - def __init__(self, config: DPRConfig, **kwargs): - super().__init__(**kwargs) - - # resolve name conflict with TFBertMainLayer instead of TFBertModel - self.bert_model = TFBertMainLayer(config, add_pooling_layer=False, name="bert_model") - self.config = config - - if self.config.hidden_size <= 0: - raise ValueError("Encoder hidden_size can't be zero") - self.projection_dim = config.projection_dim - if self.projection_dim > 0: - self.encode_proj = keras.layers.Dense( - config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj" - ) - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]: - outputs = self.bert_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - pooled_output = sequence_output[:, 0, :] - if self.projection_dim > 0: - pooled_output = self.encode_proj(pooled_output) - - if not return_dict: - return (sequence_output, pooled_output) + outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @property - def embeddings_size(self) -> int: - if self.projection_dim > 0: - return self.projection_dim - return self.bert_model.config.hidden_size - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bert_model", None) is not None: - with tf.name_scope(self.bert_model.name): - self.bert_model.build(None) - if getattr(self, "encode_proj", None) is not None: - with tf.name_scope(self.encode_proj.name): - self.encode_proj.build(None) - - -class TFDPRSpanPredictorLayer(keras.layers.Layer): - base_model_prefix = "encoder" - - def __init__(self, config: DPRConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.encoder = TFDPREncoderLayer(config, name="encoder") - - self.qa_outputs = keras.layers.Dense( - 2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.qa_classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier" - ) - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - training: bool = False, - ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]: - # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length - n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2] - # feed encoder - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - # compute logits - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - relevance_logits = self.qa_classifier(sequence_output[:, 0, :]) - - # resize - start_logits = tf.reshape(start_logits, [n_passages, sequence_length]) - end_logits = tf.reshape(end_logits, [n_passages, sequence_length]) - relevance_logits = tf.reshape(relevance_logits, [n_passages]) - - if not return_dict: - return (start_logits, end_logits, relevance_logits) + outputs[2:] - - return TFDPRReaderOutput( - start_logits=start_logits, - end_logits=end_logits, - relevance_logits=relevance_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.encoder.embeddings_size]) - if getattr(self, "qa_classifier", None) is not None: - with tf.name_scope(self.qa_classifier.name): - self.qa_classifier.build([None, None, self.encoder.embeddings_size]) - - -class TFDPRSpanPredictor(TFPreTrainedModel): - base_model_prefix = "encoder" - - def __init__(self, config: DPRConfig, **kwargs): - super().__init__(config, **kwargs) - self.encoder = TFDPRSpanPredictorLayer(config) - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - training: bool = False, - ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]: - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - -class TFDPREncoder(TFPreTrainedModel): - base_model_prefix = "encoder" - - def __init__(self, config: DPRConfig, **kwargs): - super().__init__(config, **kwargs) - - self.encoder = TFDPREncoderLayer(config) - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - training: bool = False, - ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]: - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - -################## -# PreTrainedModel -################## - - -class TFDPRPretrainedContextEncoder(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DPRConfig - base_model_prefix = "ctx_encoder" - - -class TFDPRPretrainedQuestionEncoder(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DPRConfig - base_model_prefix = "question_encoder" - - -class TFDPRPretrainedReader(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = DPRConfig - base_model_prefix = "reader" - - -############### -# Actual Models -############### - - -TF_DPR_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) - subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to - general usage and behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`DPRConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TF_DPR_ENCODERS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be - formatted with [CLS] and [SEP] tokens as follows: - - (a) For sequence pairs (for a pair title+text for example): - - ``` - tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] - token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 - ``` - - (b) For single sequences (for a question for example): - - ``` - tokens: [CLS] the dog is hairy . [SEP] - token_type_ids: 0 0 0 0 0 0 0 - ``` - - DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right - rather than the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -TF_DPR_READER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shapes `(n_passages, sequence_length)`): - Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question - and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should - be formatted with [CLS] and [SEP] with the format: - - `[CLS] [SEP] [SEP] ` - - DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right - rather than the left. - - Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details. - attention_mask (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.", - TF_DPR_START_DOCSTRING, -) -class TFDPRContextEncoder(TFDPRPretrainedContextEncoder): - def __init__(self, config: DPRConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder") - - def get_input_embeddings(self): - try: - return self.ctx_encoder.bert_model.get_input_embeddings() - except AttributeError: - self.build() - return self.ctx_encoder.bert_model.get_input_embeddings() - - @unpack_inputs - @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFDPRContextEncoderOutput | tuple[tf.Tensor, ...]: - r""" - Return: - - Examples: - - ```python - >>> from transformers import TFDPRContextEncoder, DPRContextEncoderTokenizer - - >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base") - >>> model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", from_pt=True) - >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"] - >>> embeddings = model(input_ids).pooler_output - ``` - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = ( - tf.ones(input_shape, dtype=tf.dtypes.int32) - if input_ids is None - else (input_ids != self.config.pad_token_id) - ) - if token_type_ids is None: - token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32) - - outputs = self.ctx_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return outputs[1:] - - return TFDPRContextEncoderOutput( - pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "ctx_encoder", None) is not None: - with tf.name_scope(self.ctx_encoder.name): - self.ctx_encoder.build(None) - - -@add_start_docstrings( - "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.", - TF_DPR_START_DOCSTRING, -) -class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder): - def __init__(self, config: DPRConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.question_encoder = TFDPREncoderLayer(config, name="question_encoder") - - def get_input_embeddings(self): - try: - return self.question_encoder.bert_model.get_input_embeddings() - except AttributeError: - self.build() - return self.question_encoder.bert_model.get_input_embeddings() - - @unpack_inputs - @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFDPRQuestionEncoderOutput | tuple[tf.Tensor, ...]: - r""" - Return: - - Examples: - - ```python - >>> from transformers import TFDPRQuestionEncoder, DPRQuestionEncoderTokenizer - - >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base") - >>> model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True) - >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"] - >>> embeddings = model(input_ids).pooler_output - ``` - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = ( - tf.ones(input_shape, dtype=tf.dtypes.int32) - if input_ids is None - else (input_ids != self.config.pad_token_id) - ) - if token_type_ids is None: - token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32) - - outputs = self.question_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return outputs[1:] - return TFDPRQuestionEncoderOutput( - pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "question_encoder", None) is not None: - with tf.name_scope(self.question_encoder.name): - self.question_encoder.build(None) - - -@add_start_docstrings( - "The bare DPRReader transformer outputting span predictions.", - TF_DPR_START_DOCSTRING, -) -class TFDPRReader(TFDPRPretrainedReader): - def __init__(self, config: DPRConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor") - - def get_input_embeddings(self): - try: - return self.span_predictor.encoder.bert_model.get_input_embeddings() - except AttributeError: - self.build() - return self.span_predictor.encoder.bert_model.get_input_embeddings() - - @unpack_inputs - @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]: - r""" - Return: - - Examples: - - ```python - >>> from transformers import TFDPRReader, DPRReaderTokenizer - - >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") - >>> model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True) - >>> encoded_inputs = tokenizer( - ... questions=["What is love ?"], - ... titles=["Haddaway"], - ... texts=["'What Is Love' is a song recorded by the artist Haddaway"], - ... return_tensors="tf", - ... ) - >>> outputs = model(encoded_inputs) - >>> start_logits = outputs.start_logits - >>> end_logits = outputs.end_logits - >>> relevance_logits = outputs.relevance_logits - ``` - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32) - - return self.span_predictor( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "span_predictor", None) is not None: - with tf.name_scope(self.span_predictor.name): - self.span_predictor.build(None) - - -__all__ = [ - "TFDPRContextEncoder", - "TFDPRPretrainedContextEncoder", - "TFDPRPretrainedQuestionEncoder", - "TFDPRPretrainedReader", - "TFDPRQuestionEncoder", - "TFDPRReader", -] diff --git a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index b0abc30cd758..000000000000 --- a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,79 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert ELECTRA checkpoint.""" - -import argparse - -import torch - -from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): - # Initialise PyTorch model - config = ElectraConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - - if discriminator_or_generator == "discriminator": - model = ElectraForPreTraining(config) - elif discriminator_or_generator == "generator": - model = ElectraForMaskedLM(config) - else: - raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'") - - # Load weights from tf checkpoint - load_tf_weights_in_electra( - model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator - ) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--discriminator_or_generator", - default=None, - type=str, - required=True, - help=( - "Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or " - "'generator'." - ), - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator - ) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py deleted file mode 100644 index 14d845476d62..000000000000 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ /dev/null @@ -1,1614 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_electra import ElectraConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" -_CONFIG_FOR_DOC = "ElectraConfig" - -remat = nn_partitioning.remat - - -@flax.struct.dataclass -class FlaxElectraForPreTrainingOutput(ModelOutput): - """ - Output type of [`ElectraForPreTraining`]. - - Args: - logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -ELECTRA_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`ElectraConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ELECTRA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - -""" - - -class FlaxElectraEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.embedding_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__ - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra -class FlaxElectraSelfAttention(nn.Module): - config: ElectraConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra -class FlaxElectraSelfOutput(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra -class FlaxElectraAttention(nn.Module): - config: ElectraConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra -class FlaxElectraIntermediate(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra -class FlaxElectraOutput(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra -class FlaxElectraLayer(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) - self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) - self.output = FlaxElectraOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra -class FlaxElectraLayerCollection(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra -class FlaxElectraEncoder(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxElectraLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxElectraGeneratorPredictions(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN[self.config.hidden_act](hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class FlaxElectraDiscriminatorPredictions(nn.Module): - """Prediction module for the discriminator, made up of two dense layers.""" - - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.dense_prediction = nn.Dense(1, dtype=self.dtype) - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN[self.config.hidden_act](hidden_states) - hidden_states = self.dense_prediction(hidden_states).squeeze(-1) - return hidden_states - - -class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ElectraConfig - base_model_prefix = "electra" - module_class: nn.Module = None - - def __init__( - self, - config: ElectraConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.zeros_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.ones_like(input_ids) - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxElectraAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -class FlaxElectraModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) - if self.config.embedding_size != self.config.hidden_size: - self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.encoder = FlaxElectraEncoder( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask: Optional[np.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - embeddings = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - if hasattr(self, "embeddings_project"): - embeddings = self.embeddings_project(embeddings) - - return self.encoder( - embeddings, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -@add_start_docstrings( - "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.", - ELECTRA_START_DOCSTRING, -) -class FlaxElectraModel(FlaxElectraPreTrainedModel): - module_class = FlaxElectraModule - - -append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) - - -class FlaxElectraTiedDense(nn.Module): - embedding_size: int - dtype: jnp.dtype = jnp.float32 - precision = None - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.bias = self.param("bias", self.bias_init, (self.embedding_size,)) - - def __call__(self, x, kernel): - x = jnp.asarray(x, self.dtype) - kernel = jnp.asarray(kernel, self.dtype) - y = lax.dot_general( - x, - kernel, - (((x.ndim - 1,), (0,)), ((), ())), - precision=self.precision, - ) - bias = jnp.asarray(self.bias, self.dtype) - return y + bias - - -class FlaxElectraForMaskedLMModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) - if self.config.tie_word_embeddings: - self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) - else: - self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - prediction_scores = self.generator_predictions(hidden_states) - - if self.config.tie_word_embeddings: - shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) - else: - prediction_scores = self.generator_lm_head(prediction_scores) - - if not return_dict: - return (prediction_scores,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING) -class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForMaskedLMModule - - -append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) - - -class FlaxElectraForPreTrainingModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - - logits = self.discriminator_predictions(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxElectraForPreTrainingOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. - - It is recommended to load the discriminator checkpoint into that model. - """, - ELECTRA_START_DOCSTRING, -) -class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForPreTrainingModule - - -FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") - >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ``` -""" - -overwrite_call_docstring( - FlaxElectraForPreTraining, - ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING, -) -append_replace_return_docstrings( - FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxElectraForTokenClassificationModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Electra model with a token classification head on top. - - Both the discriminator and generator may be loaded into this model. - """, - ELECTRA_START_DOCSTRING, -) -class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForTokenClassificationModule - - -append_call_sample_docstring( - FlaxElectraForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -def identity(x, **kwargs): - return x - - -class FlaxElectraSequenceSummary(nn.Module): - r""" - Compute a single vector summary of a sequence hidden states. - - Args: - config ([`PretrainedConfig`]): - The config used by the model. Relevant arguments in the config class of the model are (refer to the actual - config class of your model for the default values it uses): - - - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. - - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes - (otherwise to `config.hidden_size`). - - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, - another string or `None` will add no activation. - - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. - - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. - """ - - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.summary = identity - if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: - if ( - hasattr(self.config, "summary_proj_to_labels") - and self.config.summary_proj_to_labels - and self.config.num_labels > 0 - ): - num_classes = self.config.num_labels - else: - num_classes = self.config.hidden_size - self.summary = nn.Dense(num_classes, dtype=self.dtype) - - activation_string = getattr(self.config, "summary_activation", None) - self.activation = ACT2FN[activation_string] if activation_string else lambda x: x # noqa F407 - - self.first_dropout = identity - if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0: - self.first_dropout = nn.Dropout(self.config.summary_first_dropout) - - self.last_dropout = identity - if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: - self.last_dropout = nn.Dropout(self.config.summary_last_dropout) - - def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): - """ - Compute a single vector summary of a sequence hidden states. - - Args: - hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`): - The hidden states of the last layer. - cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): - Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. - - Returns: - `jnp.ndarray`: The summary of the sequence hidden states. - """ - # NOTE: this does "first" type summary always - output = hidden_states[:, 0] - output = self.first_dropout(output, deterministic=deterministic) - output = self.summary(output) - output = self.activation(output) - output = self.last_dropout(output, deterministic=deterministic) - return output - - -class FlaxElectraForMultipleChoiceModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[1:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ELECTRA_START_DOCSTRING, -) -class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForMultipleChoiceModule - - -# adapt docstring slightly for FlaxElectraForMultipleChoice -overwrite_call_docstring( - FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxElectraForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxElectraForQuestionAnsweringModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ELECTRA_START_DOCSTRING, -) -class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxElectraForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxElectraClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic: bool = True): - x = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, deterministic=deterministic) - x = self.dense(x) - x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu - x = self.dropout(x, deterministic=deterministic) - x = self.out_proj(x) - return x - - -class FlaxElectraForSequenceClassificationModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - logits = self.classifier(hidden_states, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ELECTRA_START_DOCSTRING, -) -class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxElectraForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxElectraForCausalLMModule(nn.Module): - config: ElectraConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.electra = FlaxElectraModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) - if self.config.tie_word_embeddings: - self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) - else: - self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask: Optional[jnp.ndarray] = None, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.electra( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - prediction_scores = self.generator_predictions(hidden_states) - - if self.config.tie_word_embeddings: - shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) - else: - prediction_scores = self.generator_lm_head(prediction_scores) - - if not return_dict: - return (prediction_scores,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for - autoregressive tasks. - """, - ELECTRA_START_DOCSTRING, -) -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra -class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel): - module_class = FlaxElectraForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxElectraForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxElectraForCausalLM", - "FlaxElectraForMaskedLM", - "FlaxElectraForMultipleChoice", - "FlaxElectraForPreTraining", - "FlaxElectraForQuestionAnswering", - "FlaxElectraForSequenceClassification", - "FlaxElectraForTokenClassification", - "FlaxElectraModel", - "FlaxElectraPreTrainedModel", -] diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py deleted file mode 100644 index 3a5c33e503d7..000000000000 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ /dev/null @@ -1,1775 +0,0 @@ -# coding=utf-8 -# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF Electra model.""" - -from __future__ import annotations - -import math -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_electra import ElectraConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" -_CONFIG_FOR_DOC = "ElectraConfig" - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra -class TFElectraSelfAttention(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFElectraModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra -class TFElectraSelfOutput(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra -class TFElectraAttention(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFElectraSelfAttention(config, name="self") - self.dense_output = TFElectraSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra -class TFElectraIntermediate(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra -class TFElectraOutput(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra -class TFElectraLayer(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFElectraAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFElectraAttention(config, name="crossattention") - self.intermediate = TFElectraIntermediate(config, name="intermediate") - self.bert_output = TFElectraOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra -class TFElectraEncoder(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra -class TFElectraPooler(keras.layers.Layer): - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra -class TFElectraEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: ElectraConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - past_key_values_length=0, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("Need to provide either `input_ids` or `input_embeds`.") - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFElectraDiscriminatorPredictions(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense(config.hidden_size, name="dense") - self.dense_prediction = keras.layers.Dense(1, name="dense_prediction") - self.config = config - - def call(self, discriminator_hidden_states, training=False): - hidden_states = self.dense(discriminator_hidden_states) - hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states) - logits = tf.squeeze(self.dense_prediction(hidden_states), -1) - - return logits - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "dense_prediction", None) is not None: - with tf.name_scope(self.dense_prediction.name): - self.dense_prediction.build([None, None, self.config.hidden_size]) - - -class TFElectraGeneratorPredictions(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dense = keras.layers.Dense(config.embedding_size, name="dense") - self.config = config - - def call(self, generator_hidden_states, training=False): - hidden_states = self.dense(generator_hidden_states) - hidden_states = get_tf_activation("gelu")(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFElectraPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ElectraConfig - base_model_prefix = "electra" - # When the model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - -@keras_serializable -class TFElectraMainLayer(keras.layers.Layer): - config_class = ElectraConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.embeddings = TFElectraEmbeddings(config, name="embeddings") - - if config.embedding_size != config.hidden_size: - self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project") - - self.encoder = TFElectraEncoder(config, name="encoder") - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0): - batch_size, seq_length = input_shape - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values_length > 0: - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype) - one_cst = tf.constant(1.0, dtype=dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - return extended_attention_mask - - def get_head_mask(self, head_mask): - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - return head_mask - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, hidden_states.dtype, past_key_values_length - ) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - head_mask = self.get_head_mask(head_mask) - - if hasattr(self, "embeddings_project"): - hidden_states = self.embeddings_project(hidden_states, training=training) - - hidden_states = self.encoder( - hidden_states=hidden_states, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "embeddings_project", None) is not None: - with tf.name_scope(self.embeddings_project.name): - self.embeddings_project.build([None, None, self.config.embedding_size]) - - -@dataclass -class TFElectraForPreTrainingOutput(ModelOutput): - """ - Output type of [`TFElectraForPreTraining`]. - - Args: - loss (*optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`): - Total loss of the ELECTRA objective. - logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Prediction scores of the head (scores for each token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -ELECTRA_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`ElectraConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ELECTRA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to " - "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the " - "hidden size and embedding size are different. " - "" - "Both the generator and discriminator checkpoints may be loaded into this model.", - ELECTRA_START_DOCSTRING, -) -class TFElectraModel(TFElectraPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.electra = TFElectraMainLayer(config, name="electra") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - - -@add_start_docstrings( - """ - Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. - - Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model - of the two to have the correct classification head to be used for this model. - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForPreTraining(TFElectraPreTrainedModel): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - - self.electra = TFElectraMainLayer(config, name="electra") - self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFElectraForPreTrainingOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFElectraForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") - >>> model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator") - >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 - >>> outputs = model(input_ids) - >>> scores = outputs[0] - ```""" - discriminator_hidden_states = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - discriminator_sequence_output = discriminator_hidden_states[0] - logits = self.discriminator_predictions(discriminator_sequence_output) - - if not return_dict: - return (logits,) + discriminator_hidden_states[1:] - - return TFElectraForPreTrainingOutput( - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "discriminator_predictions", None) is not None: - with tf.name_scope(self.discriminator_predictions.name): - self.discriminator_predictions.build(None) - - -class TFElectraMaskedLMHead(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings( - """ - Electra model with a language modeling head on top. - - Even though both the discriminator and generator may be loaded into this model, the generator is the only model of - the two to have been trained for the masked language modeling task. - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - - self.config = config - self.electra = TFElectraMainLayer(config, name="electra") - self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions") - - if isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - - self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head") - - def get_lm_head(self): - return self.generator_lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.generator_lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/electra-small-generator", - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="[MASK]", - expected_output="'paris'", - expected_loss=1.22, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - generator_hidden_states = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - generator_sequence_output = generator_hidden_states[0] - prediction_scores = self.generator_predictions(generator_sequence_output, training=training) - prediction_scores = self.generator_lm_head(prediction_scores, training=training) - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + generator_hidden_states[1:] - - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=generator_hidden_states.hidden_states, - attentions=generator_hidden_states.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "generator_predictions", None) is not None: - with tf.name_scope(self.generator_predictions.name): - self.generator_predictions.build(None) - if getattr(self, "generator_lm_head", None) is not None: - with tf.name_scope(self.generator_lm_head.name): - self.generator_lm_head.build(None) - - -class TFElectraClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - classifier_dropout = ( - config.classifhidden_dropout_probier_dropout - if config.classifier_dropout is not None - else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, inputs, **kwargs): - x = inputs[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x) - x = self.dense(x) - x = get_tf_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here - x = self.dropout(x) - x = self.out_proj(x) - - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.electra = TFElectraMainLayer(config, name="electra") - self.classifier = TFElectraClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="bhadresh-savani/electra-base-emotion", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'joy'", - expected_loss=0.06, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.classifier(outputs[0]) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.electra = TFElectraMainLayer(config, name="electra") - self.sequence_summary = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="sequence_summary" - ) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.electra( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.sequence_summary(outputs[0]) - logits = self.classifier(logits) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Electra model with a token classification head on top. - - Both the discriminator and generator may be loaded into this model. - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - - self.electra = TFElectraMainLayer(config, name="electra") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']", - expected_loss=0.11, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - discriminator_hidden_states = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - discriminator_sequence_output = discriminator_hidden_states[0] - discriminator_sequence_output = self.dropout(discriminator_sequence_output) - logits = self.classifier(discriminator_sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + discriminator_hidden_states[1:] - - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ELECTRA_START_DOCSTRING, -) -class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.electra = TFElectraMainLayer(config, name="electra") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="bhadresh-savani/electra-base-squad2", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - qa_target_start_index=11, - qa_target_end_index=12, - expected_output="'a nice puppet'", - expected_loss=2.64, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - discriminator_hidden_states = self.electra( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - discriminator_sequence_output = discriminator_hidden_states[0] - logits = self.qa_outputs(discriminator_sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = ( - start_logits, - end_logits, - ) + discriminator_hidden_states[1:] - - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "electra", None) is not None: - with tf.name_scope(self.electra.name): - self.electra.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFElectraForMaskedLM", - "TFElectraForMultipleChoice", - "TFElectraForPreTraining", - "TFElectraForQuestionAnswering", - "TFElectraForSequenceClassification", - "TFElectraForTokenClassification", - "TFElectraModel", - "TFElectraPreTrainedModel", -] diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py deleted file mode 100644 index 4a27c23c3c69..000000000000 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ /dev/null @@ -1,901 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes to support Flax Encoder-Decoder architectures""" - -import os -from typing import Optional, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput -from ...modeling_flax_utils import FlaxPreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM -from .configuration_encoder_decoder import EncoderDecoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "EncoderDecoderConfig" - -ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the - encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via - [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] - function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream - generative task, like summarization. - - The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation - tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation - Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi - Zhou, Wei Li, Peter J. Liu. - - After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models - (see the examples for more information). - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Parameters: - config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be - created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` - and prepending them with the `decoder_start_token_id`. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.encoder.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. -""" - -ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.encoder.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. -""" - -ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be - created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` - and prepending them with the `decoder_start_token_id`. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a - plain tuple. -""" - - -class FlaxEncoderDecoderModule(nn.Module): - config: EncoderDecoderConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - encoder_config = self.config.encoder - decoder_config = self.config.decoder - - # Copied from `modeling_hybrid_clip.py` with modifications. - from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING - - encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class - decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class - - self.encoder = encoder_module(encoder_config, dtype=self.dtype) - self.decoder = decoder_module(decoder_config, dtype=self.dtype) - - # encoder outputs might need to be projected to different dimension for decoder - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = nn.Dense( - self.decoder.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), - dtype=self.dtype, - ) - else: - self.enc_to_dec_proj = None - - def _get_encoder_module(self): - return self.encoder - - def _get_projection_module(self): - return self.enc_to_dec_proj - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if self.enc_to_dec_proj is not None: - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=decoder_outputs.logits, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) -class FlaxEncoderDecoderModel(FlaxPreTrainedModel): - r""" - [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with - the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as - decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the - encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. - """ - - config_class = EncoderDecoderConfig - base_model_prefix = "encoder_decoder" - module_class = FlaxEncoderDecoderModule - - def __init__( - self, - config: EncoderDecoderConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if input_shape is None: - input_shape = ((1, 1), (1, 1)) - - if not _do_init: - raise ValueError( - "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." - ) - - if config.decoder.cross_attention_hidden_size is not None: - if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" - f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" - " `config.encoder.hidden_size`." - ) - - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - encoder_input_shape, decoder_input_shape = input_shape - - # init input tensors - input_ids = jnp.zeros(encoder_input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape - if not decoder_batch_size == batch_size: - raise ValueError( - f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" - f" and {decoder_batch_size} for decoder." - ) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) - ) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer - - >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") - - >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") - - >>> text = "My friends are cool but they eat too many carbs." - >>> input_ids = tokenizer.encode(text, return_tensors="np") - >>> encoder_outputs = model.encode(input_ids) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - outputs = self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - if return_dict: - outputs = FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return outputs - - @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer - >>> import jax.numpy as jnp - - >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") - - >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") - - >>> text = "My friends are cool but they eat too many carbs." - >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(input_ids) - - >>> decoder_start_token_id = model.config.decoder.bos_token_id - >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward( - module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs - ): - projection_module = module._get_projection_module() - decoder_module = module._get_decoder_module() - - # optionally project encoder_hidden_states - if projection_module is not None: - encoder_hidden_states = projection_module(encoder_hidden_states) - - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - encoder_hidden_states=encoder_hidden_states, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Examples: - - ```python - >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer - - >>> # load a fine-tuned bert2gpt2 model - >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16") - >>> # load input & output tokenizer - >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased") - >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - - >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members - >>> singing a racist chant. SAE's national chapter suspended the students, - >>> but University of Oklahoma President David Boren took it a step further, - >>> saying the university's affiliation with the fraternity is permanently done.''' - - >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids - - >>> # use GPT2's eos_token as the pad as well as eos token - >>> model.config.eos_token_id = model.config.decoder.eos_token_id - >>> model.config.pad_token_id = model.config.eos_token_id - - >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences - - >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0] - >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members" - ``` - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - raise ValueError( - "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must" - " be specified as an input argument." - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": decoder_position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - @classmethod - def from_encoder_decoder_pretrained( - cls, - encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - *model_args, - **kwargs, - ) -> FlaxPreTrainedModel: - r""" - Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model - checkpoints. - - Params: - encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): - Information necessary to initiate the encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import FlaxEncoderDecoderModel - - >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") - >>> # saving model after fine-tuning - >>> model.save_pretrained("./bert2gpt2") - >>> # load fine-tuned model - >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2") - ```""" - - kwargs_encoder = { - argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove encoder, decoder kwargs from kwargs - for key in kwargs_encoder: - del kwargs["encoder_" + key] - for key in kwargs_decoder: - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("model", None) - if encoder is None: - if encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_encoder: - encoder_config, kwargs_encoder = AutoConfig.from_pretrained( - encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True - ) - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " - "from a decoder model. Cross-attention and causal mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_encoder["config"] = encoder_config - - encoder = FlaxAutoModel.from_pretrained( - encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder - ) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config, kwargs_decoder = AutoConfig.from_pretrained( - decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True - ) - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" - ) - - decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # instantiate config with corresponding kwargs - dtype = kwargs.pop("dtype", jnp.float32) - config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) - - # init model - model = cls(config, dtype=dtype) - model.params["encoder"] = encoder.params - model.params["decoder"] = decoder.params - - return model - - -__all__ = ["FlaxEncoderDecoderModel"] diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py deleted file mode 100644 index 7e5343d20049..000000000000 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ /dev/null @@ -1,661 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes to support TF Encoder-Decoder architectures""" - -from __future__ import annotations - -import inspect -import re -import warnings - -import numpy as np -import tensorflow as tf - -from ...configuration_utils import PretrainedConfig -from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM -from .configuration_encoder_decoder import EncoderDecoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "EncoderDecoderConfig" - -DEPRECATION_WARNING = ( - "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the" - " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" - " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the" - " labels, no need to pass them yourself anymore." -) - -ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the - encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via - [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`] - function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream - generative task, like summarization. - - The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation - tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation - Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi - Zhou, Wei Li, Peter J. Liu. - - After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models - (see the examples for more information). - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - Provide for sequence to sequence training to the decoder. Indices can be obtained using - [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for - details. - decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*): - This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output - of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `({0})`. - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert `decoder_input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. - labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, - ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: - - - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - - With a *decoder_* prefix which will be input as `**decoder_kwargs`` for the decoder forward function. -""" - - -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - if pad_token_id is None: - raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - - if decoder_start_token_id is None: - raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) -class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): - r""" - [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one - of the base model classes of the library as encoder and another one as decoder when created with the - [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class - method for the decoder. - """ - - config_class = EncoderDecoderConfig - base_model_prefix = "encoder_decoder" - load_weight_prefix = "tf_encoder_decoder_model" - - def __init__( - self, - config: PretrainedConfig | None = None, - encoder: TFPreTrainedModel | None = None, - decoder: TFPreTrainedModel | None = None, - ): - if config is None and (encoder is None or decoder is None): - raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") - if config is None: - config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) - else: - if not isinstance(config, self.config_class): - raise ValueError(f"config: {config} has to be of type {self.config_class}") - - if config.decoder.cross_attention_hidden_size is not None: - if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" - f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" - " `config.encoder.hidden_size`." - ) - - # initialize with config - super().__init__(config) - - if encoder is None: - encoder = TFAutoModel.from_config(config.encoder, name="encoder") - - if decoder is None: - decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder") - - self.encoder = encoder - self.decoder = decoder - - if self.encoder.config.to_dict() != self.config.encoder.to_dict(): - logger.warning( - f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" - f" {self.config.encoder}" - ) - if self.decoder.config.to_dict() != self.config.decoder.to_dict(): - logger.warning( - f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" - f" {self.config.decoder}" - ) - - # make sure that the individual model's config refers to the shared config - # so that the updates to the config will be synced - self.encoder.config = self.config.encoder - self.decoder.config = self.config.decoder - - # encoder outputs might need to be projected to different dimension for decoder - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = keras.layers.Dense( - units=self.decoder.config.hidden_size, - kernel_initializer=get_initializer(config.encoder.initializer_range), - name="enc_to_dec_proj", - ) - - if self.encoder.get_output_embeddings() is not None: - raise ValueError( - f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" - ) - - decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys()) - if "encoder_hidden_states" not in decoder_signature: - raise ValueError( - "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " - "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" - ) - - def get_encoder(self): - return self.encoder - - def get_input_embeddings(self): - return self.encoder.get_input_embeddings() - - def get_output_embeddings(self): - return self.decoder.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - return self.decoder.set_output_embeddings(new_embeddings) - - def tf_to_pt_weight_rename(self, tf_weight): - # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models - # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. - # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption - # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's - # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - - # This override is only needed in the case where we're crossloading weights from PT. However, since weights are - # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. - # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it - # or not. - encoder_model_type = self.config.encoder.model_type - if "encoder" in tf_weight and "decoder" not in tf_weight: - return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) - else: - return (tf_weight,) - - @classmethod - def from_encoder_decoder_pretrained( - cls, - encoder_pretrained_model_name_or_path: str | None = None, - decoder_pretrained_model_name_or_path: str | None = None, - *model_args, - **kwargs, - ) -> TFPreTrainedModel: - r""" - Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model - checkpoints. - - - Params: - encoder_pretrained_model_name_or_path (`str`, *optional*): - Information necessary to initiate the encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case, - `encoder_from_pt` should be set to `True`. - - decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case, - `decoder_from_pt` should be set to `True`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import TFEncoderDecoderModel - - >>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized - >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2") - >>> # saving model after fine-tuning - >>> model.save_pretrained("./bert2gpt2") - >>> # load fine-tuned model - >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2") - ```""" - - kwargs_encoder = { - argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove encoder, decoder kwargs from kwargs - for key in kwargs_encoder: - del kwargs["encoder_" + key] - for key in kwargs_decoder: - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("model", None) - if encoder is None: - if encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " - "from a decoder model. Cross-attention and causal mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_encoder["config"] = encoder_config - - kwargs_encoder["name"] = "encoder" - kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix - encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" - ) - - kwargs_decoder["name"] = "decoder" - kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix - decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly. - if encoder.name != "encoder": - raise ValueError("encoder model must be created with the name `encoder`.") - if decoder.name != "decoder": - raise ValueError("decoder model must be created with the name `decoder`.") - - # instantiate config with corresponding kwargs - config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) - return cls(encoder=encoder, decoder=decoder, config=config) - - @unpack_inputs - @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import TFEncoderDecoderModel, BertTokenizer - - >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") - - >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") - - >>> # forward - >>> input_ids = tokenizer.encode( - ... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf" - ... ) # Batch size 1 - >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) - - >>> # training - >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids) - >>> loss, logits = outputs.loss, outputs.logits - - >>> # save and load from pretrained - >>> model.save_pretrained("bert2gpt2") - >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2") - - >>> # generation - >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id) - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # Let the user be responsible for the expected format. - if encoder_outputs is not None: - if return_dict and not isinstance(encoder_outputs, ModelOutput): - raise ValueError( - "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of " - f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`." - ) - - if encoder_outputs is None: - encoder_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "inputs_embeds": inputs_embeds, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "return_dict": return_dict, - "training": training, - } - - # Add arguments to encoder from `kwargs_encoder` - encoder_inputs.update(kwargs_encoder) - - # Handle the case where the inputs are passed as a single dict which contains `labels`. - # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this - # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`). - if "labels" in encoder_inputs: - labels = encoder_inputs.pop("labels") - - # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. - if "decoder_input_ids" in encoder_inputs: - decoder_input_ids = encoder_inputs.pop("decoder_input_ids") - # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. - if "decoder_attention_mask" in encoder_inputs: - decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask") - - encoder_outputs = self.encoder(**encoder_inputs) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - decoder_inputs = { - "input_ids": decoder_input_ids, - "attention_mask": decoder_attention_mask, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": attention_mask, - "inputs_embeds": decoder_inputs_embeds, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "use_cache": use_cache, - "past_key_values": past_key_values, - "return_dict": return_dict, - "training": training, - } - - # Add arguments to decoder from `kwargs_decoder` - decoder_inputs.update(kwargs_decoder) - - decoder_outputs = self.decoder(**decoder_inputs) - - logits = decoder_outputs[0] - - # Compute loss independent from decoder (as some shift the logits inside them) - loss = None - if labels is not None: - warnings.warn(DEPRECATION_WARNING, FutureWarning) - loss = self.hf_compute_loss(labels, logits) - - if not return_dict: - past_key_values = None - if use_cache: - past_key_values = decoder_outputs[1] - # The starting index of the remaining elements in `decoder_outputs` - start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - - if not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs - output = tuple(x for x in output if x is not None) - return output - - return TFSeq2SeqLMOutput( - loss=loss, - logits=decoder_outputs.logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs - ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) - decoder_attention_mask = decoder_inputs.get("attention_mask", None) - past_key_values = decoder_inputs.get("past_key_values") - if past_key_values is None: - past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2 - input_dict = { - "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_input_ids": decoder_inputs["input_ids"], - # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete - "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), - "past_key_values": past_key_values, - "use_cache": use_cache, - } - return input_dict - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def resize_token_embeddings(self, *args, **kwargs): - raise NotImplementedError( - "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported.Please use the" - " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" - " model.decoder.resize_token_embeddings(...))" - ) - - def _reorder_cache(self, past, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past, beam_idx) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "enc_to_dec_proj", None) is not None: - with tf.name_scope(self.enc_to_dec_proj.name): - self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size]) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -__all__ = ["TFEncoderDecoderModel"] diff --git a/src/transformers/models/esm/modeling_tf_esm.py b/src/transformers/models/esm/modeling_tf_esm.py deleted file mode 100644 index 3fd066868f0e..000000000000 --- a/src/transformers/models/esm/modeling_tf_esm.py +++ /dev/null @@ -1,1574 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch ESM model.""" - -from __future__ import annotations - -import os - -import numpy as np -import tensorflow as tf - -from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFMaskedLMOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - shape_list, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, stable_softmax -from ...utils import logging -from .configuration_esm import EsmConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D" -_CONFIG_FOR_DOC = "EsmConfig" - - -def rotate_half(x): - x1, x2 = tf.split(x, 2, axis=-1) - return tf.concat((-x2, x1), axis=-1) - - -def apply_rotary_pos_emb(x, cos, sin): - cos = cos[:, :, : tf.shape(x)[-2], :] - sin = sin[:, :, : tf.shape(x)[-2], :] - - return (x * cos) + (rotate_half(x) * sin) - - -def symmetrize(x): - "Make layer symmetric in final two dimensions, used for contact prediction." - return x + tf.linalg.matrix_transpose(x) # Transposes last two dimensions only - - -def average_product_correct(x): - "Perform average product correct, used for contact prediction." - a1 = tf.reduce_sum(x, -1, keepdims=True) - a2 = tf.reduce_sum(x, -2, keepdims=True) - a12 = tf.reduce_sum(x, (-1, -2), keepdims=True) - - avg = a1 * a2 - avg = avg / a12 - normalized = x - avg - return normalized - - -class TFRotaryEmbedding(keras.layers.Layer): - """ - Rotary position embeddings based on those in - [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation - matrices which depend on their relative positions. - """ - - def __init__(self, dim: int, name=None): - super().__init__(name=name) - # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation - # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at - # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the - # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that - # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our - # models give different outputs from the original. - self.dim = dim - - def build(self, input_shape): - super().build(input_shape) - self.inv_freq = self.add_weight( - "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False - ) - self.inv_freq.assign( - 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) - ) - - def _compute_cos_sin(self, x, seq_dimension=2): - seq_len = tf.shape(x)[seq_dimension] - - t = tf.range(seq_len, dtype=self.inv_freq.dtype) - freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication - emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :] - - return tf.cos(emb), tf.sin(emb) - - def call(self, q: tf.Tensor, k: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]: - cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2) - - return ( - apply_rotary_pos_emb(q, cos_emb, sin_emb), - apply_rotary_pos_emb(k, cos_emb, sin_emb), - ) - - -class TFEsmContactPredictionHead(keras.layers.Layer): - """Performs symmetrization, apc, and computes a logistic regression on the output features""" - - def __init__( - self, - in_features: int, - bias=True, - eos_idx: int = 2, - name=None, - ): - super().__init__(name=name) - self.eos_idx = eos_idx - self.in_features = in_features - self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression") - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "regression", None) is not None: - with tf.name_scope(self.regression.name): - self.regression.build((None, self.in_features)) - - def call(self, tokens, attentions): - # remove eos token attentions - eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype) - eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2) - attentions = attentions * eos_mask[:, None, None, :, :] - attentions = attentions[..., :-1, :-1] - # remove cls token attentions - attentions = attentions[..., 1:, 1:] - batch_size, layers, heads, seqlen, _ = shape_list(attentions) - attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen)) - - # features: batch x channels x tokens x tokens (symmetric) - attentions = average_product_correct(symmetrize(attentions)) - attentions = tf.transpose(attentions, perm=(0, 2, 3, 1)) - return tf.squeeze(self.regression(attentions), 3) - - -class TFEsmEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - def __init__(self, config, name=None): - super().__init__(name=name) - self.word_embeddings = keras.layers.Embedding( - config.vocab_size, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="word_embeddings", - ) - self.position_embeddings = keras.layers.Embedding( - config.max_position_embeddings, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="position_embeddings", - ) - - if config.emb_layer_norm_before: - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - else: - self.layer_norm = None - # Matt: I think this line was copied incorrectly from BERT, disabling for now - # self.dropout = Dropout(config.hidden_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - - self.position_ids = tf.range(config.max_position_embeddings)[None, :] - - self.padding_idx = config.pad_token_id - self.token_dropout = config.token_dropout - self.mask_token_id = config.mask_token_id - self.config = config - - def call( - self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 - ): - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) - else: - position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.word_embeddings(input_ids) - - # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an - # embedding_scale factor here. - embeddings = inputs_embeds - - # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout - # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, - # masked tokens are treated as if they were selected for input dropout and zeroed out. - # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by - # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). - # This is analogous to the way that dropout layers scale down outputs during evaluation when not - # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). - if self.token_dropout: - embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings) - mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs - src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32) - masked_tokens = input_ids == self.mask_token_id - mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths - embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] - - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - - if self.layer_norm is not None: - embeddings = self.layer_norm(embeddings) - if attention_mask is not None: - embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype) - # Matt: I think this line was copied incorrectly from BERT, disabling it for now. - # embeddings = self.dropout(embeddings) - return embeddings - - def create_position_ids_from_inputs_embeds(self, inputs_embeds): - """ - We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. - - Args: - inputs_embeds: tf.Tensor - - Returns: tf.Tensor - """ - input_shape = shape_list(inputs_embeds)[:-1] - sequence_length = input_shape[1] - - position_ids = tf.range( - start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64 - ) - return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "word_embeddings", None) is not None: - with tf.name_scope(self.word_embeddings.name): - self.word_embeddings.build(None) - if getattr(self, "position_embeddings", None) is not None: - with tf.name_scope(self.position_embeddings.name): - self.position_embeddings.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - -class TFEsmSelfAttention(keras.layers.Layer): - def __init__(self, config, position_embedding_type=None, name=None): - super().__init__(name=name) - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - self.rotary_embeddings = None - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = keras.layers.Embedding( - 2 * config.max_position_embeddings - 1, - self.attention_head_size, - embeddings_initializer=get_initializer(config.initializer_range), - ) - elif self.position_embedding_type == "rotary": - self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings") - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, perm=(0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). - # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, - # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original - # ESM code and fix rotary embeddings. - query_layer = query_layer * self.attention_head_size**-0.5 - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = shape_list(hidden_states)[1] - position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1) - position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = attention_probs @ value_layer - - context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) - new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "rotary_embeddings", None) is not None: - with tf.name_scope(self.rotary_embeddings.name): - self.rotary_embeddings.build(None) - - -class TFEsmSelfOutput(keras.layers.Layer): - def __init__(self, config, name=None): - super().__init__(name=name) - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states += input_tensor - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFEsmAttention(keras.layers.Layer): - def __init__(self, config, name=None): - super().__init__(name=name) - self.self = TFEsmSelfAttention(config, name="self") - self.output_layer = TFEsmSelfOutput(config, name="output") - self.pruned_heads = set() - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - training=False, - ): - hidden_states_ln = self.LayerNorm(hidden_states) - self_outputs = self.self( - hidden_states_ln, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - training, - ) - attention_output = self.output_layer(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "output_layer", None) is not None: - with tf.name_scope(self.output_layer.name): - self.output_layer.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFEsmIntermediate(keras.layers.Layer): - def __init__(self, config: EsmConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = tf.nn.gelu(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFEsmOutput(keras.layers.Layer): - def __init__(self, config, name=None): - super().__init__(name=name) - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states += input_tensor - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -class TFEsmLayer(keras.layers.Layer): - def __init__(self, config, name=None): - super().__init__(name=name) - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = TFEsmAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFEsmAttention(config) - self.intermediate = TFEsmIntermediate(config, name="intermediate") - self.output_layer = TFEsmOutput(config, name="output") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - training=False, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise AttributeError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated" - " with cross-attention layers by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - layernorm_output = self.LayerNorm(attention_output) - intermediate_output = self.intermediate(hidden_states=layernorm_output) - layer_output = self.output_layer( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "output_layer", None) is not None: - with tf.name_scope(self.output_layer.name): - self.output_layer.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFEsmEncoder(keras.layers.Layer): - def __init__(self, config, name=None): - super().__init__(name=name) - self.config = config - self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - self.emb_layer_norm_after = keras.layers.LayerNormalization( - epsilon=config.layer_norm_eps, name="emb_layer_norm_after" - ) - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - training=False, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - training, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - if self.emb_layer_norm_after: - hidden_states = self.emb_layer_norm_after(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "emb_layer_norm_after", None) is not None: - with tf.name_scope(self.emb_layer_norm_after.name): - self.emb_layer_norm_after.build([None, None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm -class TFEsmPooler(keras.layers.Layer): - def __init__(self, config: EsmConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFEsmPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = EsmConfig - base_model_prefix = "esm" - - -ESM_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a - regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior. - - Parameters: - config ([`EsmConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ESM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", - ESM_START_DOCSTRING, -) -class TFEsmMainLayer(keras.layers.Layer): - """ - - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in [Attention is - all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set - to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and - `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - """ - - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config, add_pooling_layer=True, name=None, **kwargs): - super().__init__(name=name, **kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.embeddings = TFEsmEmbeddings(config, name="embeddings") - self.encoder = TFEsmEncoder(config, name="encoder") - self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None - - self.contact_head = TFEsmContactPredictionHead( - in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head" - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "contact_head", None) is not None: - with tf.name_scope(self.contact_head.name): - self.contact_head.build(None) - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.word_embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError - - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - embedding_output = self.embeddings( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def predict_contacts(self, tokens, attention_mask): - attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions - attns = tf.stack(attns, axis=1) # Matches the original model layout - # In the original model, attentions for padding tokens are completely zeroed out. - # This makes no difference most of the time because the other tokens won't attend to them, - # but it does for the contact prediction task, which takes attentions as input, - # so we have to mimic that here. - attention_mask = tf.cast(attention_mask, attns.dtype) - attns *= attention_mask[:, None, None, None] - attns *= attention_mask[:, None, None, :, None] - return self.contact_head(tokens, attns) - - -@add_start_docstrings( - "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", - ESM_START_DOCSTRING, -) -class TFEsmModel(TFEsmPreTrainedModel): - def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.esm( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def predict_contacts(self, tokens, attention_mask): - return self.esm.predict_contacts(tokens, attention_mask) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "esm", None) is not None: - with tf.name_scope(self.esm.name): - self.esm.build(None) - - -@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) -class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss): - _keys_to_ignore_on_load_missing = [r"position_ids"] - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config): - super().__init__(config) - - if config.is_decoder: - logger.warning( - "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") - self.lm_head = TFEsmLMHead(config, name="lm_head") - if config.tie_word_embeddings: - # Ensure word embeddings are built so that we actually have something to tie - with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")): - self.esm.embeddings.word_embeddings.build((None, None)) - self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0] - - def get_output_embeddings(self): - return self.lm_head.decoder - - def set_output_embeddings(self, new_embeddings): - self.lm_head.decoder = new_embeddings - - def get_lm_head(self): - return self.lm_head - - @unpack_inputs - @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - kwargs (`dict[str, any]`, *optional*, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.esm( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - masked_lm_loss = None - if labels is not None: - masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return TFMaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def predict_contacts(self, tokens, attention_mask): - return self.esm.predict_contacts(tokens, attention_mask) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "esm", None) is not None: - with tf.name_scope(self.esm.name): - self.esm.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -class TFEsmLMHead(keras.layers.Layer): - """ESM Head for masked language modeling.""" - - def __init__(self, config, name=None): - super().__init__(name=name) - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - if config.tie_word_embeddings: - self.decoder = None - else: - self.decoder = keras.layers.Dense( - config.vocab_size, - kernel_initializer=get_initializer(config.initializer_range), - name="decoder", - use_bias=False, - ) - self.config = config - - def build(self, input_shape=None): - # Separate bias to match the PT model and allow weight cross-loading to work - # Put it in the build so it gets the right name when adding it as a weight - if self.built: - return - self.built = True - self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings: - with tf.name_scope(self.decoder.name): - self.decoder.build([None, None, self.config.hidden_size]) - - def get_bias(self): - return {"bias": self.bias} - - def call(self, features): - x = self.dense(features) - x = tf.nn.gelu(x) - x = self.layer_norm(x) - - # project back to size of vocabulary with bias - if self.config.tie_word_embeddings: - x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias - else: - x = self.decoder(x) + self.bias - return x - - -@add_start_docstrings( - """ - ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - ESM_START_DOCSTRING, -) -class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss): - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") - self.classifier = TFEsmClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.esm( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "esm", None) is not None: - with tf.name_scope(self.esm.name): - self.esm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ESM_START_DOCSTRING, -) -class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense(config.num_labels, name="classifier") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.esm( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "esm", None) is not None: - with tf.name_scope(self.esm.name): - self.esm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -class TFEsmClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, name=None): - super().__init__(name=name) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.out_proj = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - activation="linear", - name="out_proj", - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: tf.Tensor x: - - Returns: tf.Tensor - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = tf.cast(input_ids != padding_idx, tf.int64) - incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask - return incremental_indices + padding_idx - - -__all__ = [ - "TFEsmForMaskedLM", - "TFEsmForSequenceClassification", - "TFEsmForTokenClassification", - "TFEsmModel", - "TFEsmPreTrainedModel", -] diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py deleted file mode 100644 index 88b7ae9f0c9d..000000000000 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ /dev/null @@ -1,1343 +0,0 @@ -# coding=utf-8 -# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TF 2.0 Flaubert model. -""" - -from __future__ import annotations - -import itertools -import random -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFSharedEmbeddings, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - MULTIPLE_CHOICE_DUMMY_INPUTS, - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_flaubert import FlaubertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "flaubert/flaubert_base_cased" -_CONFIG_FOR_DOC = "FlaubertConfig" - - -FLAUBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`FlaubertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -FLAUBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - `1` for tokens that are **not masked**, - - `0` for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - langs (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are - languages ids which can be obtained from the language names by using two conversion mappings provided in - the configuration of the model (only provided for multilingual models). More precisely, the *language name - to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the - *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). - - See usage examples detailed in the [multilingual documentation](../multilingual). - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - `0` corresponds to a *sentence A* token, - - `1` corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - lengths (`tf.Tensor` or `Numpy array` of shape `(batch_size,)`, *optional*): - Length of each sentence that can be used to avoid performing attention on padding token indices. You can - also use *attention_mask* for the same result (see above), kept here for compatibility Indices selected in - `[0, ..., input_ids.size(-1)]`: - cache (`dict[str, tf.Tensor]`, *optional*): - Dictionary string to `tf.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential - decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - `1` indicates the head is **not masked**, - - `0` indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -def get_masks(slen, lengths, causal, padding_mask=None): - """ - Generate hidden states mask, and optionally an attention mask. - """ - bs = shape_list(lengths)[0] - if padding_mask is not None: - mask = padding_mask - else: - # assert lengths.max().item() <= slen - alen = tf.range(slen, dtype=lengths.dtype) - mask = alen < tf.expand_dims(lengths, axis=1) - - # attention mask is the same as mask, or triangular inferior attention (causal) - if causal: - attn_mask = tf.less_equal( - tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1)) - ) - else: - attn_mask = mask - - # sanity check - # assert shape_list(mask) == [bs, slen] - tf.debugging.assert_equal(shape_list(mask), [bs, slen]) - if causal: - tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen]) - - return mask, attn_mask - - -class TFFlaubertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = FlaubertConfig - base_model_prefix = "transformer" - - @property - def dummy_inputs(self): - # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed - inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32) - attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32) - if self.config.use_lang_emb and self.config.n_langs > 1: - return { - "input_ids": inputs_list, - "attention_mask": attns_list, - "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32), - } - else: - return {"input_ids": inputs_list, "attention_mask": attns_list} - - -@add_start_docstrings( - "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.", - FLAUBERT_START_DOCSTRING, -) -class TFFlaubertModel(TFFlaubertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFFlaubertMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutput: - outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert -class TFFlaubertMultiHeadAttention(keras.layers.Layer): - NEW_ID = itertools.count() - - def __init__(self, n_heads, dim, config, **kwargs): - super().__init__(**kwargs) - self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID) - self.dim = dim - self.n_heads = n_heads - self.output_attentions = config.output_attentions - assert self.dim % self.n_heads == 0 - - self.q_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin") - self.k_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin") - self.v_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin") - self.out_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin") - self.dropout = keras.layers.Dropout(config.attention_dropout) - self.pruned_heads = set() - self.dim = dim - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False): - """ - Self-attention (if kv is None) or attention over source sentence (provided by kv). - """ - # Input is (bs, qlen, dim) - # Mask is (bs, klen) (non-causal) or (bs, klen, klen) - bs, qlen, dim = shape_list(input) - - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = shape_list(kv)[1] - - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - dim_per_head = self.dim // self.n_heads - mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3)) - - def unshape(x): - """compute context""" - return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) - - if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) - v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = cache[self.layer_id] - - cache[self.layer_id] = (k, v) - - f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype) - q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head) - k = tf.cast(k, dtype=q.dtype) - scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) - mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) - # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) - mask = tf.cast(mask, dtype=scores.dtype) - scores = scores - 1e30 * (1.0 - mask) - weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) - weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) - - # Mask heads if we want to - if head_mask is not None: - weights = weights * head_mask - - context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) - outputs = (self.out_lin(context),) - - if output_attentions: - outputs = outputs + (weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_lin", None) is not None: - with tf.name_scope(self.q_lin.name): - self.q_lin.build([None, None, self.dim]) - if getattr(self, "k_lin", None) is not None: - with tf.name_scope(self.k_lin.name): - self.k_lin.build([None, None, self.dim]) - if getattr(self, "v_lin", None) is not None: - with tf.name_scope(self.v_lin.name): - self.v_lin.build([None, None, self.dim]) - if getattr(self, "out_lin", None) is not None: - with tf.name_scope(self.out_lin.name): - self.out_lin.build([None, None, self.dim]) - - -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMTransformerFFN -class TFFlaubertTransformerFFN(keras.layers.Layer): - def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs): - super().__init__(**kwargs) - - self.lin1 = keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1") - self.lin2 = keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2") - self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu") - self.dropout = keras.layers.Dropout(config.dropout) - self.in_dim = in_dim - self.dim_hidden = dim_hidden - - def call(self, input, training=False): - x = self.lin1(input) - x = self.act(x) - x = self.lin2(x) - x = self.dropout(x, training=training) - - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lin1", None) is not None: - with tf.name_scope(self.lin1.name): - self.lin1.build([None, None, self.in_dim]) - if getattr(self, "lin2", None) is not None: - with tf.name_scope(self.lin2.name): - self.lin2.build([None, None, self.dim_hidden]) - - -@keras_serializable -class TFFlaubertMainLayer(keras.layers.Layer): - config_class = FlaubertConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.n_heads = config.n_heads - self.n_langs = config.n_langs - self.dim = config.emb_dim - self.hidden_dim = self.dim * 4 - self.n_words = config.n_words - self.pad_index = config.pad_index - self.causal = config.causal - self.n_layers = config.n_layers - self.use_lang_emb = config.use_lang_emb - self.layerdrop = getattr(config, "layerdrop", 0.0) - self.pre_norm = getattr(config, "pre_norm", False) - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.max_position_embeddings = config.max_position_embeddings - self.embed_init_std = config.embed_init_std - self.dropout = keras.layers.Dropout(config.dropout) - self.embeddings = TFSharedEmbeddings( - self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings" - ) - self.layer_norm_emb = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb") - self.attentions = [] - self.layer_norm1 = [] - self.ffns = [] - self.layer_norm2 = [] - - for i in range(self.n_layers): - self.attentions.append( - TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name=f"attentions_._{i}") - ) - self.layer_norm1.append( - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm1_._{i}") - ) - # if self.is_decoder: - # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) - # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) - self.ffns.append( - TFFlaubertTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name=f"ffns_._{i}") - ) - self.layer_norm2.append( - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}") - ) - - def build(self, input_shape=None): - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.dim], - initializer=get_initializer(self.embed_init_std), - ) - - if self.n_langs > 1 and self.use_lang_emb: - with tf.name_scope("lang_embeddings"): - self.lang_embeddings = self.add_weight( - name="embeddings", - shape=[self.n_langs, self.dim], - initializer=get_initializer(self.embed_init_std), - ) - - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "layer_norm_emb", None) is not None: - with tf.name_scope(self.layer_norm_emb.name): - self.layer_norm_emb.build([None, None, self.dim]) - for layer in self.attentions: - with tf.name_scope(layer.name): - layer.build(None) - for layer in self.layer_norm1: - with tf.name_scope(layer.name): - layer.build([None, None, self.dim]) - for layer in self.ffns: - with tf.name_scope(layer.name): - layer.build(None) - for layer in self.layer_norm2: - with tf.name_scope(layer.name): - layer.build([None, None, self.dim]) - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - @unpack_inputs - def call( - self, - input_ids: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutput: - # removed: src_enc=None, src_len=None - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - bs, slen = shape_list(input_ids) - elif inputs_embeds is not None: - bs, slen = shape_list(inputs_embeds)[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if lengths is None: - if input_ids is not None: - lengths = tf.reduce_sum( - tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1 - ) - else: - lengths = tf.convert_to_tensor([slen] * bs) - # mask = input_ids != self.pad_index - - # check inputs - # assert shape_list(lengths)[0] == bs - ( - tf.debugging.assert_equal(shape_list(lengths)[0], bs), - f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched", - ) - # assert lengths.max().item() <= slen - # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 - # assert (src_enc is None) == (src_len is None) - # if src_enc is not None: - # assert self.is_decoder - # assert src_enc.size(0) == bs - - # generate masks - mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) - # if self.is_decoder and src_enc is not None: - # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] - - # position_ids - if position_ids is None: - position_ids = tf.expand_dims(tf.range(slen), axis=0) - position_ids = tf.tile(position_ids, (bs, 1)) - - # assert shape_list(position_ids) == [bs, slen] # (slen, bs) - ( - tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]), - f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched", - ) - # position_ids = position_ids.transpose(0, 1) - - # langs - if langs is not None: - # assert shape_list(langs) == [bs, slen] # (slen, bs) - ( - tf.debugging.assert_equal(shape_list(langs), [bs, slen]), - f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched", - ) - # langs = langs.transpose(0, 1) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.n_layers - - # do not recompute cached elements - if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] - input_ids = input_ids[:, -_slen:] - position_ids = position_ids[:, -_slen:] - if langs is not None: - langs = langs[:, -_slen:] - mask = mask[:, -_slen:] - attn_mask = attn_mask[:, -_slen:] - - # embeddings - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size) - inputs_embeds = self.embeddings(input_ids) - - tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids) - - if langs is not None and self.use_lang_emb: - tensor = tensor + tf.gather(self.lang_embeddings, langs) - if token_type_ids is not None: - tensor = tensor + self.embeddings(token_type_ids) - - tensor = self.layer_norm_emb(tensor) - tensor = self.dropout(tensor, training=training) - mask = tf.cast(mask, dtype=tensor.dtype) - tensor = tensor * tf.expand_dims(mask, axis=-1) - - # hidden_states and attentions cannot be None in graph mode. - hidden_states = () if output_hidden_states else None - attentions = () if output_attentions else None - - # transformer layers - for i in range(self.n_layers): - # LayerDrop - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - if output_hidden_states: - hidden_states = hidden_states + (tensor,) - - # self attention - if not self.pre_norm: - attn_outputs = self.attentions[i]( - tensor, - attn_mask, - None, - cache, - head_mask[i], - output_attentions, - training=training, - ) - attn = attn_outputs[0] - - if output_attentions: - attentions = attentions + (attn_outputs[1],) - - attn = self.dropout(attn, training=training) - tensor = tensor + attn - tensor = self.layer_norm1[i](tensor) - else: - tensor_normalized = self.layer_norm1[i](tensor) - attn_outputs = self.attentions[i]( - tensor_normalized, - attn_mask, - None, - cache, - head_mask[i], - output_attentions, - training=training, - ) - attn = attn_outputs[0] - - if output_attentions: - attentions = attentions + (attn_outputs[1],) - - attn = self.dropout(attn, training=training) - tensor = tensor + attn - - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - - # FFN - if not self.pre_norm: - tensor = tensor + self.ffns[i](tensor) - tensor = self.layer_norm2[i](tensor) - else: - tensor_normalized = self.layer_norm2[i](tensor) - tensor = tensor + self.ffns[i](tensor_normalized) - - tensor = tensor * tf.expand_dims(mask, axis=-1) - - # Add last hidden state - if output_hidden_states: - hidden_states = hidden_states + (tensor,) - - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 0 - # tensor = tensor.transpose(0, 1) - - if not return_dict: - return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) - - return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) - - -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMPredLayer -class TFFlaubertPredLayer(keras.layers.Layer): - """ - Prediction layer (cross_entropy or adaptive_softmax). - """ - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.asm = config.asm - self.n_words = config.n_words - self.pad_index = config.pad_index - - if config.asm is False: - self.input_embeddings = input_embeddings - else: - raise NotImplementedError - # self.proj = nn.AdaptiveLogSoftmaxWithLoss( - # in_features=dim, - # n_classes=config.n_words, - # cutoffs=config.asm_cutoffs, - # div_value=config.asm_div_value, - # head_bias=True, # default is False - # ) - - def build(self, input_shape): - # The output weights are the same as the input embeddings, but there is an output-only bias for each token. - self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.input_embeddings(hidden_states, mode="linear") - hidden_states = hidden_states + self.bias - - return hidden_states - - -@dataclass -class TFFlaubertWithLMHeadModelOutput(ModelOutput): - """ - Base class for [`TFFlaubertWithLMHeadModel`] outputs. - - Args: - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@add_start_docstrings( - """ - The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - FLAUBERT_START_DOCSTRING, -) -class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFFlaubertMainLayer(config, name="transformer") - self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj") - # Flaubert does not have past caching features - self.supports_xla_generation = False - - def get_lm_head(self): - return self.pred_layer - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.pred_layer.name - - def prepare_inputs_for_generation(self, inputs, **kwargs): - mask_token_id = self.config.mask_token_id - lang_id = self.config.lang_id - - effective_batch_size = inputs.shape[0] - mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id - inputs = tf.concat([inputs, mask_token], axis=1) - - if lang_id is not None: - langs = tf.ones_like(inputs) * lang_id - else: - langs = None - return {"input_ids": inputs, "langs": langs} - - @unpack_inputs - @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFFlaubertWithLMHeadModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFFlaubertWithLMHeadModelOutput: - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - outputs = self.pred_layer(output) - - if not return_dict: - return (outputs,) + transformer_outputs[1:] - - return TFFlaubertWithLMHeadModelOutput( - logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "pred_layer", None) is not None: - with tf.name_scope(self.pred_layer.name): - self.pred_layer.build(None) - - -@add_start_docstrings( - """ - Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) - e.g. for GLUE tasks. - """, - FLAUBERT_START_DOCSTRING, -) -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFFlaubertMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") - - @unpack_inputs - @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - - logits = self.sequence_summary(output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - - -@add_start_docstrings( - """ - Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - FLAUBERT_START_DOCSTRING, -) -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFFlaubertMainLayer(config, name="transformer") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Flaubert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - FLAUBERT_START_DOCSTRING, -) -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFFlaubertMainLayer(config, name="transformer") - self.dropout = keras.layers.Dropout(config.dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Flaubert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - FLAUBERT_START_DOCSTRING, -) -# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.transformer = TFFlaubertMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") - self.logits_proj = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" - ) - self.config = config - - @property - def dummy_inputs(self): - """ - Dummy inputs to build the network. - - Returns: - tf.Tensor with dummy inputs - """ - # Sometimes Flaubert has language embeddings so don't forget to build them as well if needed - if self.config.use_lang_emb and self.config.n_langs > 1: - return { - "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - } - else: - return { - "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - } - - @unpack_inputs - @add_start_docstrings_to_model_forward( - FLAUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - - if lengths is not None: - logger.warning( - "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the " - "attention mask instead.", - ) - lengths = None - - transformer_outputs = self.transformer( - flat_input_ids, - flat_attention_mask, - flat_langs, - flat_token_type_ids, - flat_position_ids, - lengths, - cache, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - logits = self.sequence_summary(output) - logits = self.logits_proj(logits) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "logits_proj", None) is not None: - with tf.name_scope(self.logits_proj.name): - self.logits_proj.build([None, None, self.config.num_labels]) - - -__all__ = [ - "TFFlaubertForMultipleChoice", - "TFFlaubertForQuestionAnsweringSimple", - "TFFlaubertForSequenceClassification", - "TFFlaubertForTokenClassification", - "TFFlaubertModel", - "TFFlaubertPreTrainedModel", - "TFFlaubertWithLMHeadModel", -] diff --git a/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py deleted file mode 100644 index 71660354db14..000000000000 --- a/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py +++ /dev/null @@ -1,156 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert FNet checkpoint.""" - -import argparse - -import torch -from flax.training.checkpoints import restore_checkpoint - -from transformers import FNetConfig, FNetForPreTraining -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, fnet_config_file, save_path): - # Initialise PyTorch model - config = FNetConfig.from_json_file(fnet_config_file) - print(f"Building PyTorch model from configuration: {config}") - fnet_pretraining_model = FNetForPreTraining(config) - - checkpoint_dict = restore_checkpoint(flax_checkpoint_path, None) - pretrained_model_params = checkpoint_dict["target"] - - # Embeddings - # Position IDs - state_dict = fnet_pretraining_model.state_dict() - - position_ids = state_dict["fnet.embeddings.position_ids"] - new_state_dict = {"fnet.embeddings.position_ids": position_ids} - # Embedding Layers - new_state_dict["fnet.embeddings.word_embeddings.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["word"]["embedding"] - ) - new_state_dict["fnet.embeddings.position_embeddings.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["position"]["embedding"][0] - ) - new_state_dict["fnet.embeddings.token_type_embeddings.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["type"]["embedding"] - ) - new_state_dict["fnet.embeddings.projection.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["kernel"] - ).T - new_state_dict["fnet.embeddings.projection.bias"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["bias"] - ) - new_state_dict["fnet.embeddings.LayerNorm.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["layer_norm"]["scale"] - ) - new_state_dict["fnet.embeddings.LayerNorm.bias"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["layer_norm"]["bias"] - ) - - # Encoder Layers - for layer in range(config.num_hidden_layers): - new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.weight"] = torch.tensor( - pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["scale"] - ) - new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.bias"] = torch.tensor( - pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["bias"] - ) - - new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.weight"] = torch.tensor( - pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["kernel"] - ).T - new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.bias"] = torch.tensor( - pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["bias"] - ) - - new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.weight"] = torch.tensor( - pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["kernel"] - ).T - new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.bias"] = torch.tensor( - pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["bias"] - ) - - new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.weight"] = torch.tensor( - pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["scale"] - ) - new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.bias"] = torch.tensor( - pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["bias"] - ) - - # Pooler Layers - new_state_dict["fnet.pooler.dense.weight"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["kernel"]).T - new_state_dict["fnet.pooler.dense.bias"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["bias"]) - - # Masked LM Layers - new_state_dict["cls.predictions.transform.dense.weight"] = torch.tensor( - pretrained_model_params["predictions_dense"]["kernel"] - ).T - new_state_dict["cls.predictions.transform.dense.bias"] = torch.tensor( - pretrained_model_params["predictions_dense"]["bias"] - ) - new_state_dict["cls.predictions.transform.LayerNorm.weight"] = torch.tensor( - pretrained_model_params["predictions_layer_norm"]["scale"] - ) - new_state_dict["cls.predictions.transform.LayerNorm.bias"] = torch.tensor( - pretrained_model_params["predictions_layer_norm"]["bias"] - ) - new_state_dict["cls.predictions.decoder.weight"] = torch.tensor( - pretrained_model_params["encoder"]["embedder"]["word"]["embedding"] - ) - new_state_dict["cls.predictions.decoder.bias"] = torch.tensor( - pretrained_model_params["predictions_output"]["output_bias"] - ) - new_state_dict["cls.predictions.bias"] = torch.tensor(pretrained_model_params["predictions_output"]["output_bias"]) - - # Seq Relationship Layers - new_state_dict["cls.seq_relationship.weight"] = torch.tensor( - pretrained_model_params["classification"]["output_kernel"] - ) - new_state_dict["cls.seq_relationship.bias"] = torch.tensor( - pretrained_model_params["classification"]["output_bias"] - ) - - # Load State Dict - fnet_pretraining_model.load_state_dict(new_state_dict) - - # Save PreTrained - print(f"Saving pretrained model to {save_path}") - fnet_pretraining_model.save_pretrained(save_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--fnet_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained FNet model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument("--save_path", default=None, type=str, required=True, help="Path to the output model.") - args = parser.parse_args() - convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.fnet_config_file, args.save_path) diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 4eab188f2ab7..000000000000 --- a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert Funnel checkpoint.""" - -import argparse - -import torch - -from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model): - # Initialise PyTorch model - config = FunnelConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = FunnelBaseModel(config) if base_model else FunnelModel(config) - - # Load weights from tf checkpoint - load_tf_weights_in_funnel(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--base_model", action="store_true", help="Whether you want just the base model (no decoder) or not." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.base_model - ) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py deleted file mode 100644 index 3d57fa99eaa1..000000000000 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ /dev/null @@ -1,1883 +0,0 @@ -# coding=utf-8 -# Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Funnel model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_funnel import FunnelConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "FunnelConfig" - - -INF = 1e6 - - -class TFFunnelEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.initializer_std = 1.0 if config.initializer_std is None else config.initializer_std - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_std), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.d_model]) - - def call(self, input_ids=None, inputs_embeds=None, training=False): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - assert not (input_ids is not None and inputs_embeds is not None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(self.weight, input_ids) - - final_embeddings = self.LayerNorm(inputs=inputs_embeds) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFFunnelAttentionStructure: - """ - Contains helpers for `TFFunnelRelMultiheadAttention `. - """ - - cls_token_type_id: int = 2 - - def __init__(self, config): - self.d_model = config.d_model - self.attention_type = config.attention_type - self.num_blocks = config.num_blocks - self.separate_cls = config.separate_cls - self.truncate_seq = config.truncate_seq - self.pool_q_only = config.pool_q_only - self.pooling_type = config.pooling_type - - self.sin_dropout = keras.layers.Dropout(config.hidden_dropout) - self.cos_dropout = keras.layers.Dropout(config.hidden_dropout) - # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was - # divided. - self.pooling_mult = None - - def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None, training=False): - """Returns the attention inputs associated to the inputs of the model.""" - # inputs_embeds has shape batch_size x seq_len x d_model - # attention_mask and token_type_ids have shape batch_size x seq_len - self.pooling_mult = 1 - self.seq_len = seq_len = shape_list(inputs_embeds)[1] - position_embeds = self.get_position_embeds(seq_len, training=training) - token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None - cls_mask = ( - tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]]) - if self.separate_cls - else None - ) - return (position_embeds, token_type_mat, attention_mask, cls_mask) - - def token_type_ids_to_mat(self, token_type_ids): - """Convert `token_type_ids` to `token_type_mat`.""" - token_type_mat = tf.equal(tf.expand_dims(token_type_ids, -1), tf.expand_dims(token_type_ids, -2)) - # Treat as in the same segment as both A & B - cls_ids = tf.equal(token_type_ids, tf.constant([self.cls_token_type_id], dtype=token_type_ids.dtype)) - cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2)) - return tf.logical_or(cls_mat, token_type_mat) - - def get_position_embeds(self, seq_len, training=False): - """ - Create and cache inputs related to relative position encoding. Those are very different depending on whether we - are using the factorized or the relative shift attention: - - For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2, - final formula. - - For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final - formula. - - Paper link: https://huggingface.co/papers/2006.03236 - """ - if self.attention_type == "factorized": - # Notations from the paper, appending A.2.2, final formula. - # We need to create and return the matrices phi, psi, pi and omega. - pos_seq = tf.range(0, seq_len, 1.0) - freq_seq = tf.range(0, self.d_model // 2, 1.0) - inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) - sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq) - - sin_embed = tf.sin(sinusoid) - sin_embed_d = self.sin_dropout(sin_embed, training=training) - cos_embed = tf.cos(sinusoid) - cos_embed_d = self.cos_dropout(cos_embed, training=training) - # This is different from the formula on the paper... - phi = tf.concat([sin_embed_d, sin_embed_d], axis=-1) - psi = tf.concat([cos_embed, sin_embed], axis=-1) - pi = tf.concat([cos_embed_d, cos_embed_d], axis=-1) - omega = tf.concat([-sin_embed, cos_embed], axis=-1) - return (phi, pi, psi, omega) - else: - # Notations from the paper, appending A.2.1, final formula. - # We need to create and return all the possible vectors R for all blocks and shifts. - freq_seq = tf.range(0, self.d_model // 2, 1.0) - inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) - # Maximum relative positions for the first input - rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0) - zero_offset = seq_len * tf.constant(2) - sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq) - sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training) - cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training) - pos_embed = tf.concat([sin_embed, cos_embed], axis=-1) - - pos = tf.range(0, seq_len) - pooled_pos = pos - position_embeds_list = [] - for block_index in range(0, self.num_blocks): - # For each block with block_index > 0, we need two types position embeddings: - # - Attention(pooled-q, unpooled-kv) - # - Attention(pooled-q, pooled-kv) - # For block_index = 0 we only need the second one and leave the first one as None. - - # First type - position_embeds_pooling = tf.fill([1], value=-1.0) - - if block_index != 0: - pooled_pos = self.stride_pool_pos(pos, block_index) - - # construct rel_pos_id - stride = 2 ** (block_index - 1) - rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2) - # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset - # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) - rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) - rel_pos = rel_pos + zero_offset - position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0) - - # Second type - pos = pooled_pos - stride = 2**block_index - rel_pos = self.relative_pos(pos, stride) - - # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset - # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) - rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) - rel_pos = rel_pos + zero_offset - tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0]) - position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0) - - position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling]) - return position_embeds_list - - def stride_pool_pos(self, pos_id, block_index): - """ - Pool `pos_id` while keeping the cls token separate (if `self.separate_cls=True`). - """ - if self.separate_cls: - # Under separate , we treat the as the first token in - # the previous block of the 1st real block. Since the 1st real - # block always has position 1, the position of the previous block - # will be at `1 - 2 ** block_index`. - cls_pos = tf.constant([-(2**block_index) + 1], dtype=pos_id.dtype) - pooled_pos_id = pos_id[1:-1] if self.truncate_seq else pos_id[1:] - return tf.concat([cls_pos, pooled_pos_id[::2]], 0) - else: - return pos_id[::2] - - def relative_pos(self, pos, stride, pooled_pos=None, shift=1): - """ - Build the relative positional vector between `pos` and `pooled_pos`. - """ - if pooled_pos is None: - pooled_pos = pos - - ref_point = pooled_pos[0] - pos[0] - num_remove = shift * shape_list(pooled_pos)[0] - max_dist = ref_point + num_remove * stride - min_dist = pooled_pos[0] - pos[-1] - - return tf.range(max_dist, min_dist - 1, -stride) - - def stride_pool(self, tensor, axis): - """ - Perform pooling by stride slicing the tensor along the given axis. - """ - if tensor is None: - return None - - # Do the stride pool recursively if axis is a list or a tuple of ints. - if isinstance(axis, (list, tuple)): - for ax in axis: - tensor = self.stride_pool(tensor, ax) - return tensor - - # Do the stride pool recursively if tensor is a list or tuple of tensors. - if isinstance(tensor, (tuple, list)): - return type(tensor)(self.stride_pool(x, axis) for x in tensor) - - # Deal with negative axis - axis %= len(shape_list(tensor)) - - axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2) - enc_slice = [slice(None)] * axis + [axis_slice] - if self.separate_cls: - cls_slice = [slice(None)] * axis + [slice(None, 1)] - tensor = tf.concat([tensor[cls_slice], tensor], axis) - return tensor[enc_slice] - - def pool_tensor(self, tensor, mode="mean", stride=2): - """Apply 1D pooling to a tensor of size [B x T (x H)].""" - if tensor is None: - return None - - # Do the pool recursively if tensor is a list or tuple of tensors. - if isinstance(tensor, (tuple, list)): - return type(tensor)(self.pool_tensor(tensor, mode=mode, stride=stride) for x in tensor) - - if self.separate_cls: - suffix = tensor[:, :-1] if self.truncate_seq else tensor - tensor = tf.concat([tensor[:, :1], suffix], axis=1) - - ndim = len(shape_list(tensor)) - if ndim == 2: - tensor = tensor[:, :, None] - - if mode == "mean": - tensor = tf.nn.avg_pool1d(tensor, stride, strides=stride, data_format="NWC", padding="SAME") - elif mode == "max": - tensor = tf.nn.max_pool1d(tensor, stride, strides=stride, data_format="NWC", padding="SAME") - elif mode == "min": - tensor = -tf.nn.max_pool1d(-tensor, stride, strides=stride, data_format="NWC", padding="SAME") - else: - raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.") - - return tf.squeeze(tensor, 2) if ndim == 2 else tensor - - def pre_attention_pooling(self, output, attention_inputs): - """Pool `output` and the proper parts of `attention_inputs` before the attention layer.""" - position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs - if self.pool_q_only: - if self.attention_type == "factorized": - position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:] - token_type_mat = self.stride_pool(token_type_mat, 1) - cls_mask = self.stride_pool(cls_mask, 0) - output = self.pool_tensor(output, mode=self.pooling_type) - else: - self.pooling_mult *= 2 - if self.attention_type == "factorized": - position_embeds = self.stride_pool(position_embeds, 0) - token_type_mat = self.stride_pool(token_type_mat, [1, 2]) - cls_mask = self.stride_pool(cls_mask, [1, 2]) - attention_mask = self.pool_tensor(attention_mask, mode="min") - output = self.pool_tensor(output, mode=self.pooling_type) - attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask) - return output, attention_inputs - - def post_attention_pooling(self, attention_inputs): - """Pool the proper parts of `attention_inputs` after the attention layer.""" - position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs - if self.pool_q_only: - self.pooling_mult *= 2 - if self.attention_type == "factorized": - position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0) - token_type_mat = self.stride_pool(token_type_mat, 2) - cls_mask = self.stride_pool(cls_mask, 1) - attention_mask = self.pool_tensor(attention_mask, mode="min") - attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask) - return attention_inputs - - -def _relative_shift_gather(positional_attn, context_len, shift): - batch_size, n_head, seq_len, max_rel_len = shape_list(positional_attn) - # max_rel_len = 2 * context_len + shift -1 is the numbers of possible relative positions i-j - - # What's next is the same as doing the following gather in PyTorch, which might be clearer code but less efficient. - # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1) - # # matrix of context_len + i-j - # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len])) - - positional_attn = tf.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len]) - positional_attn = positional_attn[:, :, shift:, :] - positional_attn = tf.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift]) - positional_attn = positional_attn[..., :context_len] - return positional_attn - - -class TFFunnelRelMultiheadAttention(keras.layers.Layer): - def __init__(self, config, block_index, **kwargs): - super().__init__(**kwargs) - self.attention_type = config.attention_type - self.n_head = n_head = config.n_head - self.d_head = d_head = config.d_head - self.d_model = d_model = config.d_model - self.initializer_range = config.initializer_range - self.block_index = block_index - - self.hidden_dropout = keras.layers.Dropout(config.hidden_dropout) - self.attention_dropout = keras.layers.Dropout(config.attention_dropout) - - initializer = get_initializer(config.initializer_range) - - self.q_head = keras.layers.Dense( - n_head * d_head, use_bias=False, kernel_initializer=initializer, name="q_head" - ) - self.k_head = keras.layers.Dense(n_head * d_head, kernel_initializer=initializer, name="k_head") - self.v_head = keras.layers.Dense(n_head * d_head, kernel_initializer=initializer, name="v_head") - - self.post_proj = keras.layers.Dense(d_model, kernel_initializer=initializer, name="post_proj") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.scale = 1.0 / (d_head**0.5) - - def build(self, input_shape=None): - n_head, d_head, d_model = self.n_head, self.d_head, self.d_model - initializer = get_initializer(self.initializer_range) - - self.r_w_bias = self.add_weight( - shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_w_bias" - ) - self.r_r_bias = self.add_weight( - shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_r_bias" - ) - self.r_kernel = self.add_weight( - shape=(d_model, n_head, d_head), initializer=initializer, trainable=True, name="r_kernel" - ) - self.r_s_bias = self.add_weight( - shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_s_bias" - ) - self.seg_embed = self.add_weight( - shape=(2, n_head, d_head), initializer=initializer, trainable=True, name="seg_embed" - ) - - if self.built: - return - self.built = True - if getattr(self, "q_head", None) is not None: - with tf.name_scope(self.q_head.name): - self.q_head.build([None, None, d_model]) - if getattr(self, "k_head", None) is not None: - with tf.name_scope(self.k_head.name): - self.k_head.build([None, None, d_model]) - if getattr(self, "v_head", None) is not None: - with tf.name_scope(self.v_head.name): - self.v_head.build([None, None, d_model]) - if getattr(self, "post_proj", None) is not None: - with tf.name_scope(self.post_proj.name): - self.post_proj.build([None, None, n_head * d_head]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, d_model]) - - def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None): - """Relative attention score for the positional encodings""" - # q_head has shape batch_size x sea_len x n_head x d_head - if self.attention_type == "factorized": - # Notations from the paper, appending A.2.2, final formula (https://huggingface.co/papers/2006.03236) - # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model - phi, pi, psi, omega = position_embeds - # Shape n_head x d_head - u = self.r_r_bias * self.scale - # Shape d_model x n_head x d_head - w_r = self.r_kernel - - # Shape batch_size x sea_len x n_head x d_model - q_r_attention = tf.einsum("binh,dnh->bind", q_head + u, w_r) - q_r_attention_1 = q_r_attention * phi[:, None] - q_r_attention_2 = q_r_attention * pi[:, None] - - # Shape batch_size x n_head x seq_len x context_len - positional_attn = tf.einsum("bind,jd->bnij", q_r_attention_1, psi) + tf.einsum( - "bind,jd->bnij", q_r_attention_2, omega - ) - else: - # Notations from the paper, appending A.2.1, final formula (https://huggingface.co/papers/2006.03236) - # Grab the proper positional encoding, shape max_rel_len x d_model - if shape_list(q_head)[1] != context_len: - shift = 2 - r = position_embeds[self.block_index][1] - else: - shift = 1 - r = position_embeds[self.block_index][0] - # Shape n_head x d_head - v = self.r_r_bias * self.scale - # Shape d_model x n_head x d_head - w_r = self.r_kernel - - # Shape max_rel_len x n_head x d_model - r_head = tf.einsum("td,dnh->tnh", r, w_r) - # Shape batch_size x n_head x seq_len x max_rel_len - positional_attn = tf.einsum("binh,tnh->bnit", q_head + v, r_head) - # Shape batch_size x n_head x seq_len x context_len - positional_attn = _relative_shift_gather(positional_attn, context_len, shift) - - if cls_mask is not None: - positional_attn *= cls_mask - return positional_attn - - def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None): - """Relative attention score for the token_type_ids""" - if token_type_mat is None: - return 0 - batch_size, seq_len, context_len = shape_list(token_type_mat) - # q_head has shape batch_size x seq_len x n_head x d_head - # Shape n_head x d_head - r_s_bias = self.r_s_bias * self.scale - - # Shape batch_size x n_head x seq_len x 2 - token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed) - # Shape batch_size x n_head x seq_len x context_len - token_type_mat = tf.tile(token_type_mat[:, None], [1, shape_list(q_head)[2], 1, 1]) - # token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape) - # Shapes batch_size x n_head x seq_len - diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1) - # Shape batch_size x n_head x seq_len x context_len - token_type_attn = tf.where( - token_type_mat, - tf.tile(same_token_type, [1, 1, 1, context_len]), - tf.tile(diff_token_type, [1, 1, 1, context_len]), - ) - - if cls_mask is not None: - token_type_attn *= cls_mask - return token_type_attn - - def call(self, query, key, value, attention_inputs, output_attentions=False, training=False): - # query has shape batch_size x seq_len x d_model - # key and value have shapes batch_size x context_len x d_model - position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs - - batch_size, seq_len, _ = shape_list(query) - context_len = shape_list(key)[1] - n_head, d_head = self.n_head, self.d_head - - # Shape batch_size x seq_len x n_head x d_head - q_head = tf.reshape(self.q_head(query), [batch_size, seq_len, n_head, d_head]) - # Shapes batch_size x context_len x n_head x d_head - k_head = tf.reshape(self.k_head(key), [batch_size, context_len, n_head, d_head]) - v_head = tf.reshape(self.v_head(value), [batch_size, context_len, n_head, d_head]) - - q_head = q_head * self.scale - # Shape n_head x d_head - r_w_bias = self.r_w_bias * self.scale - # Shapes batch_size x n_head x seq_len x context_len - content_score = tf.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head) - positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask) - token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask) - - # merge attention scores - attn_score = content_score + positional_attn + token_type_attn - - # perform masking - if attention_mask is not None: - attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype) - attn_score = attn_score - (INF * (1 - attention_mask[:, None, None])) - - # attention probability - attn_prob = stable_softmax(attn_score, axis=-1) - attn_prob = self.attention_dropout(attn_prob, training=training) - - # attention output, shape batch_size x seq_len x n_head x d_head - attn_vec = tf.einsum("bnij,bjnd->bind", attn_prob, v_head) - - # Shape shape batch_size x seq_len x d_model - attn_out = self.post_proj(tf.reshape(attn_vec, [batch_size, seq_len, n_head * d_head])) - attn_out = self.hidden_dropout(attn_out, training=training) - - output = self.layer_norm(query + attn_out) - return (output, attn_prob) if output_attentions else (output,) - - -class TFFunnelPositionwiseFFN(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - initializer = get_initializer(config.initializer_range) - self.linear_1 = keras.layers.Dense(config.d_inner, kernel_initializer=initializer, name="linear_1") - self.activation_function = get_tf_activation(config.hidden_act) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.linear_2 = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_2") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.config = config - - def call(self, hidden, training=False): - h = self.linear_1(hidden) - h = self.activation_function(h) - h = self.activation_dropout(h, training=training) - h = self.linear_2(h) - h = self.dropout(h, training=training) - return self.layer_norm(hidden + h) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "linear_1", None) is not None: - with tf.name_scope(self.linear_1.name): - self.linear_1.build([None, None, self.config.d_model]) - if getattr(self, "linear_2", None) is not None: - with tf.name_scope(self.linear_2.name): - self.linear_2.build([None, None, self.config.d_inner]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - - -class TFFunnelLayer(keras.layers.Layer): - def __init__(self, config, block_index, **kwargs): - super().__init__(**kwargs) - self.attention = TFFunnelRelMultiheadAttention(config, block_index, name="attention") - self.ffn = TFFunnelPositionwiseFFN(config, name="ffn") - - def call(self, query, key, value, attention_inputs, output_attentions=False, training=False): - attn = self.attention( - query, key, value, attention_inputs, output_attentions=output_attentions, training=training - ) - output = self.ffn(attn[0], training=training) - return (output, attn[1]) if output_attentions else (output,) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "ffn", None) is not None: - with tf.name_scope(self.ffn.name): - self.ffn.build(None) - - -class TFFunnelEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.separate_cls = config.separate_cls - self.pool_q_only = config.pool_q_only - self.block_repeats = config.block_repeats - self.attention_structure = TFFunnelAttentionStructure(config) - self.blocks = [ - [TFFunnelLayer(config, block_index, name=f"blocks_._{block_index}_._{i}") for i in range(block_size)] - for block_index, block_size in enumerate(config.block_sizes) - ] - - def call( - self, - inputs_embeds, - attention_mask=None, - token_type_ids=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - training=False, - ): - # The pooling is not implemented on long tensors, so we convert this mask. - # attention_mask = tf.cast(attention_mask, inputs_embeds.dtype) - attention_inputs = self.attention_structure.init_attention_inputs( - inputs_embeds, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - training=training, - ) - hidden = inputs_embeds - - all_hidden_states = (inputs_embeds,) if output_hidden_states else None - all_attentions = () if output_attentions else None - - for block_index, block in enumerate(self.blocks): - pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1) - pooling_flag = pooling_flag and block_index > 0 - pooled_hidden = tf.zeros(shape_list(hidden)) - - if pooling_flag: - pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling( - hidden, attention_inputs - ) - - for layer_index, layer in enumerate(block): - for repeat_index in range(self.block_repeats[block_index]): - do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag - if do_pooling: - query = pooled_hidden - key = value = hidden if self.pool_q_only else pooled_hidden - else: - query = key = value = hidden - layer_output = layer( - query, key, value, attention_inputs, output_attentions=output_attentions, training=training - ) - hidden = layer_output[0] - if do_pooling: - attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs) - - if output_attentions: - all_attentions = all_attentions + layer_output[1:] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden,) - - if not return_dict: - return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None) - return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - for block in self.blocks: - for layer in block: - with tf.name_scope(layer.name): - layer.build(None) - - -def upsample(x, stride, target_len, separate_cls=True, truncate_seq=False): - """ - Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length dimension. - """ - if stride == 1: - return x - if separate_cls: - cls = x[:, :1] - x = x[:, 1:] - output = tf.repeat(x, repeats=stride, axis=1) - if separate_cls: - if truncate_seq: - output = tf.pad(output, [[0, 0], [0, stride - 1], [0, 0]]) - output = output[:, : target_len - 1] - output = tf.concat([cls, output], axis=1) - else: - output = output[:, :target_len] - return output - - -class TFFunnelDecoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.separate_cls = config.separate_cls - self.truncate_seq = config.truncate_seq - self.stride = 2 ** (len(config.block_sizes) - 1) - self.attention_structure = TFFunnelAttentionStructure(config) - self.layers = [TFFunnelLayer(config, 0, name=f"layers_._{i}") for i in range(config.num_decoder_layers)] - - def call( - self, - final_hidden, - first_block_hidden, - attention_mask=None, - token_type_ids=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - training=False, - ): - upsampled_hidden = upsample( - final_hidden, - stride=self.stride, - target_len=shape_list(first_block_hidden)[1], - separate_cls=self.separate_cls, - truncate_seq=self.truncate_seq, - ) - - hidden = upsampled_hidden + first_block_hidden - all_hidden_states = (hidden,) if output_hidden_states else None - all_attentions = () if output_attentions else None - - attention_inputs = self.attention_structure.init_attention_inputs( - hidden, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - training=training, - ) - - for layer in self.layers: - layer_output = layer( - hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions, training=training - ) - hidden = layer_output[0] - - if output_attentions: - all_attentions = all_attentions + layer_output[1:] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden,) - - if not return_dict: - return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None) - return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFFunnelBaseLayer(keras.layers.Layer): - """Base model without decoder""" - - config_class = FunnelConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - - self.embeddings = TFFunnelEmbeddings(config, name="embeddings") - self.encoder = TFFunnelEncoder(config, name="encoder") - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids, training=training) - - encoder_outputs = self.encoder( - inputs_embeds, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return encoder_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -@keras_serializable -class TFFunnelMainLayer(keras.layers.Layer): - """Base model with decoder""" - - config_class = FunnelConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.block_sizes = config.block_sizes - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - - self.embeddings = TFFunnelEmbeddings(config, name="embeddings") - self.encoder = TFFunnelEncoder(config, name="encoder") - self.decoder = TFFunnelDecoder(config, name="decoder") - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids, training=training) - - encoder_outputs = self.encoder( - inputs_embeds, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - training=training, - ) - - decoder_outputs = self.decoder( - final_hidden=encoder_outputs[0], - first_block_hidden=encoder_outputs[1][self.block_sizes[0]], - attention_mask=attention_mask, - token_type_ids=token_type_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - idx = 0 - outputs = (decoder_outputs[0],) - if output_hidden_states: - idx += 1 - outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],) - if output_attentions: - idx += 1 - outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],) - return outputs - - return TFBaseModelOutput( - last_hidden_state=decoder_outputs[0], - hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states) - if output_hidden_states - else None, - attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -class TFFunnelDiscriminatorPredictions(keras.layers.Layer): - """Prediction module for the discriminator, made up of two dense layers.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - initializer = get_initializer(config.initializer_range) - self.dense = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="dense") - self.activation_function = get_tf_activation(config.hidden_act) - self.dense_prediction = keras.layers.Dense(1, kernel_initializer=initializer, name="dense_prediction") - self.config = config - - def call(self, discriminator_hidden_states): - hidden_states = self.dense(discriminator_hidden_states) - hidden_states = self.activation_function(hidden_states) - logits = tf.squeeze(self.dense_prediction(hidden_states)) - return logits - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.d_model]) - if getattr(self, "dense_prediction", None) is not None: - with tf.name_scope(self.dense_prediction.name): - self.dense_prediction.build([None, None, self.config.d_model]) - - -class TFFunnelMaskedLMHead(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - self.config = config - self.hidden_size = config.hidden_size - self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states, training=False): - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -class TFFunnelClassificationHead(keras.layers.Layer): - def __init__(self, config, n_labels, **kwargs): - super().__init__(**kwargs) - initializer = get_initializer(config.initializer_range) - self.linear_hidden = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_hidden") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.linear_out = keras.layers.Dense(n_labels, kernel_initializer=initializer, name="linear_out") - self.config = config - - def call(self, hidden, training=False): - hidden = self.linear_hidden(hidden) - hidden = keras.activations.tanh(hidden) - hidden = self.dropout(hidden, training=training) - return self.linear_out(hidden) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "linear_hidden", None) is not None: - with tf.name_scope(self.linear_hidden.name): - self.linear_hidden.build([None, None, self.config.d_model]) - if getattr(self, "linear_out", None) is not None: - with tf.name_scope(self.linear_out.name): - self.linear_out.build([None, None, self.config.d_model]) - - -class TFFunnelPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = FunnelConfig - base_model_prefix = "funnel" - - @property - def dummy_inputs(self): - # Funnel misbehaves with very small inputs, so we override and make them a bit bigger - return {"input_ids": tf.ones((1, 3), dtype=tf.int32)} - - -@dataclass -class TFFunnelForPreTrainingOutput(ModelOutput): - """ - Output type of [`FunnelForPreTraining`]. - - Args: - logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Prediction scores of the head (scores for each token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -FUNNEL_START_DOCSTRING = r""" - - The Funnel Transformer model was proposed in [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient - Language Processing](https://huggingface.co/papers/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`XxxConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -FUNNEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - """ - The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called - decoder) or any task-specific head on top. - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelBaseModel(TFFunnelPreTrainedModel): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.funnel = TFFunnelBaseLayer(config, name="funnel") - - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small-base", - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFBaseModelOutput: - return self.funnel( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - def serving_output(self, output): - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFBaseModelOutput( - last_hidden_state=output.last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - - -@add_start_docstrings( - "The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.", - FUNNEL_START_DOCSTRING, -) -class TFFunnelModel(TFFunnelPreTrainedModel): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.funnel = TFFunnelMainLayer(config, name="funnel") - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small", - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFBaseModelOutput: - return self.funnel( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - def serving_output(self, output): - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFBaseModelOutput( - last_hidden_state=output.last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - - -@add_start_docstrings( - """ - Funnel model with a binary classification head on top as used during pretraining for identifying generated tokens. - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelForPreTraining(TFFunnelPreTrainedModel): - def __init__(self, config: FunnelConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.funnel = TFFunnelMainLayer(config, name="funnel") - self.discriminator_predictions = TFFunnelDiscriminatorPredictions(config, name="discriminator_predictions") - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> tuple[tf.Tensor] | TFFunnelForPreTrainingOutput: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFFunnelForPreTraining - >>> import torch - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - - >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small") - >>> model = TFFunnelForPreTraining.from_pretrained("funnel-transformer/small") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") - >>> logits = model(inputs).logits - ```""" - discriminator_hidden_states = self.funnel( - input_ids, - attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - discriminator_sequence_output = discriminator_hidden_states[0] - logits = self.discriminator_predictions(discriminator_sequence_output) - - if not return_dict: - return (logits,) + discriminator_hidden_states[1:] - - return TFFunnelForPreTrainingOutput( - logits=logits, - hidden_states=discriminator_hidden_states.hidden_states, - attentions=discriminator_hidden_states.attentions, - ) - - def serving_output(self, output): - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFFunnelForPreTrainingOutput( - logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "discriminator_predictions", None) is not None: - with tf.name_scope(self.discriminator_predictions.name): - self.discriminator_predictions.build(None) - - -@add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING) -class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - - self.funnel = TFFunnelMainLayer(config, name="funnel") - self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head") - - def get_lm_head(self) -> TFFunnelMaskedLMHead: - return self.lm_head - - def get_prefix_bias_name(self) -> str: - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small", - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFMaskedLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.funnel( - input_ids, - attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -@add_start_docstrings( - """ - Funnel Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.funnel = TFFunnelBaseLayer(config, name="funnel") - self.classifier = TFFunnelClassificationHead(config, config.num_labels, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small-base", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFSequenceClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.funnel( - input_ids, - attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - last_hidden_state = outputs[0] - pooled_output = last_hidden_state[:, 0] - logits = self.classifier(pooled_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput: - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFSequenceClassifierOutput( - logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - Funnel Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - - self.funnel = TFFunnelBaseLayer(config, name="funnel") - self.classifier = TFFunnelClassificationHead(config, 1, name="classifier") - - @property - def dummy_inputs(self): - return {"input_ids": tf.ones((3, 3, 4), dtype=tf.int32)} - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small-base", - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFMultipleChoiceModelOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - - outputs = self.funnel( - flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = outputs[0] - pooled_output = last_hidden_state[:, 0] - logits = self.classifier(pooled_output, training=training) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput: - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFMultipleChoiceModelOutput( - logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - Funnel Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.funnel = TFFunnelMainLayer(config, name="funnel") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFTokenClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.funnel( - input_ids, - attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput: - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFTokenClassifierOutput( - logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Funnel Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - FUNNEL_START_DOCSTRING, -) -class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.funnel = TFFunnelMainLayer(config, name="funnel") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="funnel-transformer/small", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFQuestionAnsweringModelOutput: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - - outputs = self.funnel( - input_ids, - attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions, "end_position": end_positions} - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput: - # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of - # different dimensions - return TFQuestionAnsweringModelOutput( - start_logits=output.start_logits, - end_logits=output.end_logits, - hidden_states=output.hidden_states, - attentions=output.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "funnel", None) is not None: - with tf.name_scope(self.funnel.name): - self.funnel.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFFunnelBaseModel", - "TFFunnelForMaskedLM", - "TFFunnelForMultipleChoice", - "TFFunnelForPreTraining", - "TFFunnelForQuestionAnswering", - "TFFunnelForSequenceClassification", - "TFFunnelForTokenClassification", - "TFFunnelModel", - "TFFunnelPreTrainedModel", -] diff --git a/src/transformers/models/gemma/modeling_flax_gemma.py b/src/transformers/models/gemma/modeling_flax_gemma.py deleted file mode 100644 index 0addcd7dde7a..000000000000 --- a/src/transformers/models/gemma/modeling_flax_gemma.py +++ /dev/null @@ -1,777 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google Inc., and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Gemma model.""" - -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_gemma import GemmaConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "GemmaConfig" -_CHECKPOINT_FOR_DOC = "google/gemma-2b" -_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2" - -GEMMA_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`GemmaConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or - `jax.numpy.bfloat16`. - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -GEMMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def create_sinusoidal_positions(num_pos, dim): - inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim)) - freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") - - emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) - return jnp.array(out[:, :, :num_pos]) - - -# Copied from transformers.models.llama.modeling_flax_llama.rotate_half -def rotate_half(tensor): - """Rotates half the hidden dims of the input.""" - rotate_half_tensor = jnp.concatenate( - (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 - ) - return rotate_half_tensor - - -# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): - return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) - - -class FlaxGemmaRMSNorm(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.epsilon = self.config.rms_norm_eps - self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) - - def __call__(self, hidden_states): - variance = jnp.asarray(hidden_states, dtype=jnp.float32) - variance = jnp.power(variance, 2) - variance = variance.mean(-1, keepdims=True) - # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` - hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) - - return (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype) - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma -class FlaxGemmaRotaryEmbedding(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - # Ignore copy - def setup(self): - head_dim = self.config.head_dim - self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) - - def __call__(self, key, query, position_ids): - sincos = self.sincos[position_ids] - sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) - - key = apply_rotary_pos_emb(key, sin_pos, cos_pos) - query = apply_rotary_pos_emb(query, sin_pos, cos_pos) - - key = jnp.asarray(key, dtype=self.dtype) - query = jnp.asarray(query, dtype=self.dtype) - - return key, query - - -class FlaxGemmaAttention(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - causal: bool = True - is_cross_attention: bool = False - - def setup(self): - config = self.config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 - - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - kernel = jax.nn.initializers.normal(self.config.initializer_range) - self.q_proj = nn.Dense( - self.num_heads * self.head_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel - ) - self.k_proj = nn.Dense( - self.num_key_value_heads * self.head_dim, - use_bias=config.attention_bias, - dtype=self.dtype, - kernel_init=kernel, - ) - self.v_proj = nn.Dense( - self.num_key_value_heads * self.head_dim, - use_bias=config.attention_bias, - dtype=self.dtype, - kernel_init=kernel, - ) - self.o_proj = nn.Dense(self.embed_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel) - - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype) - - def _split_heads(self, hidden_states, num_heads): - return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,)) - - @nn.compact - # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - position_ids, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, self.num_heads) - key = self._split_heads(key, self.num_key_value_heads) - value = self._split_heads(value, self.num_key_value_heads) - - key, query = self.rotary_emb(key, query, position_ids) - - query_length, key_length = query.shape[1], key.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - - dropout_rng = None - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - - key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2) - value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2) - - # usual dot product attention - attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=attention_dtype, - ) - - if self.attention_softmax_in_fp32: - attn_weights = attn_weights.astype(self.dtype) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.o_proj(attn_output) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxGemmaMLP(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim - - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - if self.config.hidden_activation is None: - logger.warning_once( - "Gemma's activation function should be approximate GeLU and not exact GeLU. " - "Changing the activation function to `gelu_pytorch_tanh`." - f"if you want to use the legacy `{self.config.hidden_act}`, " - f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` " - " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - hidden_activation = "gelu_pytorch_tanh" - else: - hidden_activation = self.config.hidden_activation - self.act = ACT2FN[hidden_activation] - - self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - - def __call__(self, hidden_states): - up_proj_states = self.up_proj(hidden_states) - gate_states = self.act(self.gate_proj(hidden_states)) - - hidden_states = self.down_proj(up_proj_states * gate_states) - return hidden_states - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma -class FlaxGemmaDecoderLayer(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.input_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) - self.self_attn = FlaxGemmaAttention(self.config, dtype=self.dtype) - self.post_attention_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) - self.mlp = FlaxGemmaMLP(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - outputs = self.self_attn( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - # residual connection - attn_output = outputs[0] - hidden_states = residual + attn_output - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + hidden_states - - return (hidden_states,) + outputs[1:] - - -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma, GPT_NEO->GEMMA, transformer->model -class FlaxGemmaPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GemmaConfig - base_model_prefix = "model" - module_class: nn.Module = None - - def __init__( - self, - config: GemmaConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGemmaAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma -class FlaxGemmaLayerCollection(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxGemmaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxGemmaModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma -class FlaxGemmaModule(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.hidden_size = self.config.hidden_size - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.hidden_size, - embedding_init=embedding_init, - dtype=self.dtype, - ) - self.layers = FlaxGemmaLayerCollection(self.config, dtype=self.dtype) - self.norm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) - - # Ignore copy - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.embed_tokens(input_ids.astype("i4")) - - input_embeds = input_embeds * (self.config.hidden_size**0.5) - - outputs = self.layers( - input_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare Gemma Model transformer outputting raw hidden-states without any specific head on top.", - GEMMA_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma -class FlaxGemmaModel(FlaxGemmaPreTrainedModel): - module_class = FlaxGemmaModule - - -append_call_sample_docstring( - FlaxGemmaModel, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutput, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma -class FlaxGemmaForCausalLMModule(nn.Module): - config: GemmaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.model = FlaxGemmaModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - # Ignore copy - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.model( - input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The Gemma Model transformer with a language modeling head (linear layer) on top. - """, - GEMMA_START_DOCSTRING, -) -# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma -class FlaxGemmaForCausalLM(FlaxGemmaPreTrainedModel): - module_class = FlaxGemmaForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Gemma uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxGemmaForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutput, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - - -__all__ = ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"] diff --git a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 33f9dabed07f..000000000000 --- a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OpenAI GPT checkpoint.""" - -import argparse - -import torch - -from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2 -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -logging.set_verbosity_info() - - -def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): - # Construct model - if gpt2_config_file == "": - config = GPT2Config() - else: - config = GPT2Config.from_json_file(gpt2_config_file) - model = GPT2Model(config) - - # Load weights from numpy - load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME - print(f"Save PyTorch model to {pytorch_weights_dump_path}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {pytorch_config_dump_path}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--gpt2_config_file", - default="", - type=str, - help=( - "An optional config json file corresponding to the pre-trained OpenAI model. \n" - "This specifies the model architecture." - ), - ) - args = parser.parse_args() - convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py deleted file mode 100644 index 8e419217c5a3..000000000000 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ /dev/null @@ -1,782 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_gpt2 import GPT2Config - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "openai-community/gpt2" -_CONFIG_FOR_DOC = "GPT2Config" - - -GPT2_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`GPT2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -GPT2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxConv1D(nn.Module): - features: int - use_bias: bool = True - dtype: Any = jnp.float32 - precision: Any = None - - @nn.compact - def __call__(self, inputs): - inputs = jnp.asarray(inputs, self.dtype) - kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1])) - kernel = jnp.asarray(kernel.transpose(), self.dtype) - y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) - if self.use_bias: - bias = self.param("bias", jax.nn.initializers.zeros, (self.features,)) - bias = jnp.asarray(bias, self.dtype) - y = y + bias - return y - - -class FlaxGPT2Attention(nn.Module): - config: GPT2Config - dtype: jnp.dtype = jnp.float32 - causal: bool = True - is_cross_attention: bool = False - - def setup(self): - config = self.config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - - if self.is_cross_attention: - self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype) - self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype) - else: - self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype) - self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype) - - self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - if not is_cross_attention: - qkv_out = self.c_attn(hidden_states) - query, key, value = jnp.split(qkv_out, 3, axis=2) - else: - q_out = self.q_attn(hidden_states) - (query,) = jnp.split(q_out, 1, axis=2) - kv_out = self.c_attn(key_value_states) - key, value = jnp.split(kv_out, 2, axis=2) - - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) - - query_length, key_length = query.shape[1], key.shape[1] - - if self.causal: - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - dropout_rng = None - if not deterministic and self.config.attn_pdrop > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - if attention_mask is not None: - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - # usual dot product attention - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attn_pdrop, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxGPT2MLP(nn.Module): - config: GPT2Config - intermediate_size: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype) - self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype) - self.act = ACT2FN[self.config.activation_function] - self.dropout = nn.Dropout(rate=self.config.resid_pdrop) - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxGPT2Block(nn.Module): - config: GPT2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - hidden_size = self.config.hidden_size - inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype) - self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - if self.config.add_cross_attention: - self.crossattention = FlaxGPT2Attention( - config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True - ) - self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - # residual connection - attn_output = attn_outputs[0] # output_attn: a, (attentions) - outputs = attn_outputs[1:] - # residual connection - hidden_states = attn_output + residual - - # Cross-Attention Block - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = cross_attn_outputs[0] - # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights - - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) - # residual connection - hidden_states = residual + feed_forward_hidden_states - - outputs = (hidden_states,) + outputs - - return outputs - - -class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPT2Config - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: GPT2Config, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if encoder_hidden_states is not None and encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - encoder_hidden_states, - encoder_attention_mask, - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxGPT2BlockCollection(nn.Module): - config: GPT2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = block( - hidden_states, - attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # this contains possible `None` values - `FlaxGPT2Module` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - return outputs - - -class FlaxGPT2Module(nn.Module): - config: GPT2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embed_dim = self.config.hidden_size - - self.wte = nn.Embed( - self.config.vocab_size, - self.embed_dim, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.wpe = nn.Embed( - self.config.max_position_embeddings, - self.embed_dim, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.embd_pdrop) - self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype) - self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.wte(input_ids.astype("i4")) - position_embeds = self.wpe(position_ids.astype("i4")) - - hidden_states = input_embeds + position_embeds - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.h( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[2], - cross_attentions=outputs[3], - ) - - -@add_start_docstrings( - "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", - GPT2_START_DOCSTRING, -) -class FlaxGPT2Model(FlaxGPT2PreTrainedModel): - module_class = FlaxGPT2Module - - -append_call_sample_docstring( - FlaxGPT2Model, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPastAndCrossAttentions, - _CONFIG_FOR_DOC, -) - - -class FlaxGPT2LMHeadModule(nn.Module): - config: GPT2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.transformer( - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - GPT2_START_DOCSTRING, -) -class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel): - module_class = FlaxGPT2LMHeadModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since GPT2 uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice( - extended_attention_mask, attention_mask.astype("i4"), (0, 0) - ) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxGPT2LMHeadModel, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"] diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py deleted file mode 100644 index 42e23fc29015..000000000000 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ /dev/null @@ -1,1238 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 OpenAI GPT-2 model.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFSequenceClassifierOutputWithPast, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFConv1D, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - TFSequenceSummary, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_gpt2 import GPT2Config - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "openai-community/gpt2" -_CONFIG_FOR_DOC = "GPT2Config" - - -class TFAttention(keras.layers.Layer): - def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs): - super().__init__(**kwargs) - - n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implementation] - assert n_state % config.n_head == 0 - self.n_head = config.n_head - self.split_size = n_state - self.scale = scale - self.output_attentions = config.output_attentions - - self.is_cross_attention = is_cross_attention - - if self.is_cross_attention: - self.c_attn = TFConv1D(n_state * 2, nx, initializer_range=config.initializer_range, name="c_attn") - self.q_attn = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="q_attn") - else: - self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") - - self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") - self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) - self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) - self.pruned_heads = set() - self.embed_dim = n_state - - def prune_heads(self, heads): - pass - - @staticmethod - def causal_attention_mask(nd, ns, dtype): - """ - 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), - -1, ns-nd), but doesn't produce garbage on TPUs. - """ - i = tf.range(nd)[:, None] - j = tf.range(ns) - m = i >= j - ns + nd - return tf.cast(m, dtype) - - def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False): - # q, k, v have shape [batch, heads, sequence, features] - w = tf.matmul(q, k, transpose_b=True) - if self.scale: - dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores - w = w / tf.math.sqrt(dk) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - - # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. - _, _, nd, ns = shape_list(w) - b = self.causal_attention_mask(nd, ns, dtype=w.dtype) - b = tf.reshape(b, [1, 1, nd, ns]) - w = w * b - 1e4 * (1 - b) - - if attention_mask is not None: - # Apply the attention mask - attention_mask = tf.cast(attention_mask, dtype=w.dtype) - w = w + attention_mask - - w = stable_softmax(w, axis=-1) - w = self.attn_dropout(w, training=training) - - # Mask heads if we want to - if head_mask is not None: - w = w * head_mask - - outputs = [tf.matmul(w, v)] - if output_attentions: - outputs.append(w) - return outputs - - def merge_heads(self, x): - x = tf.transpose(x, [0, 2, 1, 3]) - x_shape = shape_list(x) - new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] - return tf.reshape(x, new_x_shape) - - def split_heads(self, x): - x_shape = shape_list(x) - new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - - def call( - self, - x, - layer_past, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - training=False, - ): - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(x) - kv_out = self.c_attn(encoder_hidden_states) - key, value = tf.split(kv_out, 2, axis=2) - attention_mask = encoder_attention_mask - else: - x = self.c_attn(x) - query, key, value = tf.split(x, 3, axis=2) - - query = self.split_heads(query) - key = self.split_heads(key) - value = self.split_heads(value) - if layer_past is not None: - past_key, past_value = tf.unstack(layer_past, axis=0, num=2) - key = tf.concat([past_key, key], axis=-2) - value = tf.concat([past_value, value], axis=-2) - - # to cope with keras serialization - if use_cache: - present = tf.stack([key, value], axis=0) - else: - present = (None,) - - attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training) - a = attn_outputs[0] - - a = self.merge_heads(a) - a = self.c_proj(a) - a = self.resid_dropout(a, training=training) - - outputs = [a, present] + attn_outputs[1:] - return outputs # a, present, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.is_cross_attention: - c_attn_shape = 2 * self.embed_dim - else: - c_attn_shape = 3 * self.embed_dim - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.embed_dim]) - if getattr(self, "c_attn", None) is not None: - with tf.name_scope(self.c_attn.name): - self.c_attn.build([None, None, c_attn_shape]) - if getattr(self, "q_attn", None) is not None: - with tf.name_scope(self.q_attn.name): - self.q_attn.build([None, None, self.embed_dim]) - - -class TFMLP(keras.layers.Layer): - def __init__(self, n_state, config, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") - self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") - self.act = get_tf_activation(config.activation_function) - self.dropout = keras.layers.Dropout(config.resid_pdrop) - self.intermediate_size = n_state - self.embed_dim = nx - - def call(self, x, training=False): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - h2 = self.dropout(h2, training=training) - return h2 - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "c_fc", None) is not None: - with tf.name_scope(self.c_fc.name): - self.c_fc.build([None, None, self.intermediate_size]) - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.embed_dim]) - - -class TFBlock(keras.layers.Layer): - def __init__(self, config, scale=False, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - inner_dim = config.n_inner if config.n_inner is not None else 4 * nx - self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") - self.attn = TFAttention(nx, config, scale, name="attn") - self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") - - if config.add_cross_attention: - self.crossattention = TFAttention(nx, config, scale, name="crossattention", is_cross_attention=True) - self.ln_cross_attn = keras.layers.LayerNormalization( - epsilon=config.layer_norm_epsilon, name="ln_cross_attn" - ) - - self.mlp = TFMLP(inner_dim, config, name="mlp") - self.hidden_size = config.hidden_size - - def call( - self, - x, - layer_past, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - training=False, - ): - a = self.ln_1(x) - output_attn = self.attn( - a, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - a = output_attn[0] # output_attn: a, present, (attentions) - outputs = output_attn[1:] - x = x + a - - # Cross-Attention Block - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - - ca = self.ln_cross_attn(x) - output_cross_attn = self.crossattention( - ca, - layer_past=None, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=False, - output_attentions=output_attentions, - training=training, - ) - ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions) - x = x + ca - outputs = outputs + output_cross_attn[2:] # add cross attentions if we output attention weights - - m = self.ln_2(x) - m = self.mlp(m, training=training) - x = x + m - - outputs = [x] + outputs - return outputs # x, present, (attentions, cross_attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "ln_1", None) is not None: - with tf.name_scope(self.ln_1.name): - self.ln_1.build([None, None, self.hidden_size]) - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "ln_2", None) is not None: - with tf.name_scope(self.ln_2.name): - self.ln_2.build([None, None, self.hidden_size]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - if getattr(self, "ln_cross_attn", None) is not None: - with tf.name_scope(self.ln_cross_attn.name): - self.ln_cross_attn.build([None, None, self.hidden_size]) - - -@keras_serializable -class TFGPT2MainLayer(keras.layers.Layer): - config_class = GPT2Config - - def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - self.config = config - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.use_cache = config.use_cache - self.return_dict = config.use_return_dict - - self.num_hidden_layers = config.n_layer - self.n_embd = config.n_embd - self.n_positions = config.n_positions - self.initializer_range = config.initializer_range - - self.wte = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="wte", - ) - self.wpe = keras.layers.Embedding( - input_dim=config.n_positions, - output_dim=config.n_embd, - embeddings_initializer=get_initializer(config.initializer_range), - name="wpe", - ) - self.drop = keras.layers.Dropout(config.embd_pdrop) - self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] - self.ln_f = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f") - self.embed_dim = config.hidden_size - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_length = 0 - past_key_values = [None] * len(self.h) - else: - past_length = shape_list(past_key_values[0][0])[-2] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) - - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - one_cst = tf.constant(1.0) - attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) - attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.config.add_cross_attention and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=encoder_hidden_states.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - encoder_attention_mask = encoder_extended_attention_mask - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - # head_mask = tf.constant([0] * self.num_hidden_layers) - - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.wte(input_ids) - - position_embeds = self.wpe(position_ids) - - if token_type_ids is not None: - token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - token_type_embeds = self.wte(token_type_ids) - else: - token_type_embeds = tf.constant(0.0) - - position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype) - token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) - hidden_states = inputs_embeds + position_embeds + token_type_embeds - hidden_states = self.drop(hidden_states, training=training) - - output_shape = input_shape + [shape_list(hidden_states)[-1]] - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - - outputs = block( - hidden_states, - layer_past, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - training=training, - ) - - hidden_states, present = outputs[:2] - if use_cache: - presents = presents + (present,) - - if output_attentions: - all_attentions = all_attentions + (outputs[2],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (outputs[3],) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = tf.reshape(hidden_states, output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if output_attentions: - # let the number of heads free (-1) so we can extract attention even after head pruning - attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] - all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions] - if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wte", None) is not None: - with tf.name_scope(self.wte.name): - self.wte.build(None) - if getattr(self, "wpe", None) is not None: - with tf.name_scope(self.wpe.name): - self.wpe.build(None) - if getattr(self, "ln_f", None) is not None: - with tf.name_scope(self.ln_f.name): - self.ln_f.build([None, None, self.embed_dim]) - if getattr(self, "h", None) is not None: - for layer in self.h: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFGPT2PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPT2Config - base_model_prefix = "transformer" - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"] - - @property - def input_signature(self): - # Although GPT-2 supports token_type_ids in theory, in practice they are rarely used, and the implementation - # means that passing token_type_ids=0 yields different outputs from token_type_ids=None. - # Therefore, we remove the token_type_ids argument by default, even though it would usually be included. - return { - "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), - } - - -@dataclass -class TFGPT2DoubleHeadsModelOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - mc_logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -GPT2_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`GPT2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -GPT2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]` - (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. - - If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The token ids which have - their past given to this model should not be passed as input ids as they have already been computed. - attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for - `past_key_values`. In other words, the `attention_mask` always has to have the length: - `len(past_key_values) + len(input_ids)` - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, input_ids_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", - GPT2_START_DOCSTRING, -) -class TFGPT2Model(TFGPT2PreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFGPT2MainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have - their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past`). Set to `False` during training, `True` during generation - """ - - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - GPT2_START_DOCSTRING, -) -class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFGPT2MainLayer(config, name="transformer") - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids") - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - inputs = tf.expand_dims(inputs[:, -1], -1) - if token_type_ids is not None: - token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) - - position_ids = kwargs.get("position_ids") - attention_mask = kwargs.get("attention_mask") - - if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) - if past_key_values: - position_ids = tf.expand_dims(position_ids[:, -1], -1) - - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "token_type_ids": token_type_ids, - } - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have - their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True) - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for - RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the - input embeddings, the classification head takes as input the input of a specified classification token index in the - input sequence). - """, - GPT2_START_DOCSTRING, -) -class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - config.num_labels = 1 - self.transformer = TFGPT2MainLayer(config, name="transformer") - self.multiple_choice_head = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="multiple_choice_head" - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - mc_token_ids: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFGPT2DoubleHeadsModelOutput | tuple[tf.Tensor]: - r""" - mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFGPT2DoubleHeadsModel - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - >>> model = TFGPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") - - >>> # Add a [CLS] to the vocabulary (we should train it also!) - >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) - - >>> embedding_layer = model.resize_token_embeddings( - ... len(tokenizer) - ... ) # Update the model embeddings with the new vocabulary size - - >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] - >>> encoded_choices = [tokenizer.encode(s) for s in choices] - >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] - - >>> input_ids = tf.constant(encoded_choices)[None, :] # Batch size: 1, number of choices: 2 - >>> mc_token_ids = tf.constant([cls_token_location]) # Batch size: 1 - - >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) - >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] - ```""" - - if input_ids is not None: - input_shapes = shape_list(input_ids) - else: - input_shapes = shape_list(inputs_embeds)[:-1] - - seq_length = input_shapes[-1] - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - transformer_outputs = self.transformer( - input_ids=flat_input_ids, - past_key_values=past_key_values, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) - if return_dict and output_hidden_states: - # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the - # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) - all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) - else: - all_hidden_states = None - lm_logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) - mc_logits = tf.squeeze(mc_logits, axis=-1) - - if not return_dict: - return (lm_logits, mc_logits) + transformer_outputs[1:] - - return TFGPT2DoubleHeadsModelOutput( - logits=lm_logits, - mc_logits=mc_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=all_hidden_states, - attentions=transformer_outputs.attentions, - ) - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), - "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="mc_token_ids"), - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "multiple_choice_head", None) is not None: - with tf.name_scope(self.multiple_choice_head.name): - self.multiple_choice_head.build(None) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a sequence classification head on top (linear layer). - - [`TFGPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-1) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - GPT2_START_DOCSTRING, -) -class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.score = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="score", - use_bias=False, - ) - self.transformer = TFGPT2MainLayer(config, name="transformer") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint="microsoft/DialogRPT-updown", - output_type=TFSequenceClassifierOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutputWithPast | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "score", None) is not None: - with tf.name_scope(self.score.name): - self.score.build([None, None, self.config.n_embd]) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -__all__ = [ - "TFGPT2DoubleHeadsModel", - "TFGPT2ForSequenceClassification", - "TFGPT2LMHeadModel", - "TFGPT2MainLayer", - "TFGPT2Model", - "TFGPT2PreTrainedModel", -] diff --git a/src/transformers/models/gpt2/tokenization_gpt2_tf.py b/src/transformers/models/gpt2/tokenization_gpt2_tf.py deleted file mode 100644 index 145a45da0db6..000000000000 --- a/src/transformers/models/gpt2/tokenization_gpt2_tf.py +++ /dev/null @@ -1,119 +0,0 @@ -import os -from typing import Optional, Union - -import tensorflow as tf -from tensorflow_text import pad_model_inputs - -from ...modeling_tf_utils import keras -from ...utils.import_utils import is_keras_nlp_available, requires -from .tokenization_gpt2 import GPT2Tokenizer - - -if is_keras_nlp_available(): - from keras_nlp.tokenizers import BytePairTokenizer - - -@requires(backends=("keras_nlp",)) -class TFGPT2Tokenizer(keras.layers.Layer): - """ - This is an in-graph tokenizer for GPT2. It should be initialized similarly to other tokenizers, using the - `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings - from an existing standard tokenizer object. - - In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run - when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options - than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes - straight from `tf.string` inputs to outputs. - - Args: - vocab (dict[str, int]): Vocabulary dict for Byte Pair Tokenizer - merges (list[str]): Merges list for Byte Pair Tokenizer - """ - - def __init__( - self, - vocab: dict[str, int], - merges: list[str], - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - ): - super().__init__() - self.pad_token_id = pad_token_id - self.max_length = max_length - self.vocab = vocab - self.merges = merges - - self.tf_tokenizer = BytePairTokenizer(vocab, merges, sequence_length=max_length) - - @classmethod - def from_tokenizer(cls, tokenizer: GPT2Tokenizer, *args, **kwargs): - """Creates TFGPT2Tokenizer from GPT2Tokenizer - - Args: - tokenizer (GPT2Tokenizer) - - Examples: - - ```python - from transformers import AutoTokenizer, TFGPT2Tokenizer - - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(tokenizer) - ``` - """ - merges = [" ".join(m) for m in tokenizer.bpe_ranks] - vocab = tokenizer.get_vocab() - return cls(vocab, merges, *args, **kwargs) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): - """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer - - Args: - pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model - - Examples: - - ```python - from transformers import TFGPT2Tokenizer - - tf_tokenizer = TFGPT2Tokenizer.from_pretrained("openai-community/gpt2") - ``` - """ - tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) - return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs) - - @classmethod - def from_config(cls, config): - """Creates TFGPT2Tokenizer from configurations - - Args: - config (Dict): Dictionary with keys such as stated in `get_config`. - """ - return cls(**config) - - def get_config(self): - return { - "vocab": self.vocab, - "merges": self.merges, - "max_length": self.max_length, - "pad_token_id": self.pad_token_id, - } - - def call(self, x, max_length: Optional[int] = None): - input_ids = self.tf_tokenizer(x) - attention_mask = tf.ones_like(input_ids) - - if self.pad_token_id is not None: - # pad the tokens up to max length - max_length = max_length if max_length is not None else self.max_length - - if max_length is not None: - input_ids, attention_mask = pad_model_inputs( - input_ids, max_seq_length=max_length, pad_value=self.pad_token_id - ) - - return {"attention_mask": attention_mask, "input_ids": input_ids} - - -__all__ = ["TFGPT2Tokenizer"] diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py deleted file mode 100644 index 3db22857293c..000000000000 --- a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py +++ /dev/null @@ -1,71 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert GPT Neo checkpoint.""" - -import argparse -import json - -from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - config_json = json.load(open(config_file, "r")) - config = GPTNeoConfig( - hidden_size=config_json["n_embd"], - num_layers=config_json["n_layer"], - num_heads=config_json["n_head"], - attention_types=config_json["attention_types"], - max_position_embeddings=config_json["n_positions"], - resid_dropout=config_json["res_dropout"], - embed_dropout=config_json["embed_dropout"], - attention_dropout=config_json["attn_dropout"], - ) - print(f"Building PyTorch model from configuration: {config}") - model = GPTNeoForCausalLM(config) - - # Load weights from tf checkpoint - load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained mesh-tf model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py deleted file mode 100644 index a6cdc50b359b..000000000000 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ /dev/null @@ -1,687 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Eleuther AI and The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_gpt_neo import GPTNeoConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "GPTNeoConfig" -_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B" - - -GPT_NEO_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -GPT_NEO_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxGPTNeoSelfAttention(nn.Module): - config: GPTNeoConfig - attention_type: str - dtype: jnp.dtype = jnp.float32 - - def setup(self): - config = self.config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and " - f"`num_heads`: {self.num_heads})." - ) - - self.attn_dropout = nn.Dropout(config.attention_dropout) - self.resid_dropout = nn.Dropout(config.resid_dropout) - - dense = partial( - nn.Dense, - self.embed_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False) - self.out_proj = dense() - - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - if self.attention_type == "local": - self.causal_mask = self.causal_mask ^ jnp.tril(self.causal_mask, -config.window_size) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) * jnp.sqrt(self.head_dim).astype(self.dtype) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) - - query_length, key_length = query.shape[1], key.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - - dropout_rng = None - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - - # usual dot product attention - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxGPTNeoAttention(nn.Module): - config: GPTNeoConfig - layer_id: int = 0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - attention_type = self.config.attention_layers[self.layer_id] - self.attention = FlaxGPTNeoSelfAttention(self.config, attention_type, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - return self.attention( - hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - - -class FlaxGPTNeoMLP(nn.Module): - config: GPTNeoConfig - intermediate_size: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) - self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) - self.act = ACT2FN[self.config.activation_function] - self.dropout = nn.Dropout(rate=self.config.resid_dropout) - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxGPTNeoBlock(nn.Module): - config: GPTNeoConfig - layer_id: int = 0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - hidden_size = self.config.hidden_size - inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - self.attn = FlaxGPTNeoAttention(self.config, layer_id=self.layer_id, dtype=self.dtype) - self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - self.mlp = FlaxGPTNeoMLP(self.config, inner_dim, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - outputs = self.attn( - hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - # residual connection - attn_output = outputs[0] - hidden_states = attn_output + residual - - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) - # residual connection - hidden_states = residual + feed_forward_hidden_states - - return (hidden_states,) + outputs[1:] - - -class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPTNeoConfig - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: GPTNeoConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTNeoAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxGPTNeoBlockCollection(nn.Module): - config: GPTNeoConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxGPTNeoBlock(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = block( - hidden_states, - attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxGPTNeoModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -class FlaxGPTNeoModule(nn.Module): - config: GPTNeoConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embed_dim = self.config.hidden_size - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - self.wte = nn.Embed( - self.config.vocab_size, - self.embed_dim, - embedding_init=embedding_init, - ) - self.wpe = nn.Embed( - self.config.max_position_embeddings, - self.embed_dim, - embedding_init=embedding_init, - ) - self.dropout = nn.Dropout(rate=self.config.embed_dropout) - self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype) - self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.wte(input_ids.astype("i4")) - position_embeds = self.wpe(position_ids.astype("i4")) - - hidden_states = input_embeds + position_embeds - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.h( - hidden_states, - attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.", - GPT_NEO_START_DOCSTRING, -) -class FlaxGPTNeoModel(FlaxGPTNeoPreTrainedModel): - module_class = FlaxGPTNeoModule - - -append_call_sample_docstring(FlaxGPTNeoModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) - - -class FlaxGPTNeoForCausalLMModule(nn.Module): - config: GPTNeoConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.transformer = FlaxGPTNeoModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.transformer( - input_ids, - attention_mask, - position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The GPTNeo Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - GPT_NEO_START_DOCSTRING, -) -class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel): - module_class = FlaxGPTNeoForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since GPTNeo uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring(FlaxGPTNeoForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC) - - -__all__ = ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py deleted file mode 100644 index 12ea7a4fffb4..000000000000 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ /dev/null @@ -1,721 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The EleutherAI and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_gptj import GPTJConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "gptj" -_CONFIG_FOR_DOC = "GPTJConfig" - - -GPTJ_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`GPTJConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -GPTJ_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def create_sinusoidal_positions(num_pos, dim): - inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) - sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") - sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) - - sentinel = dim // 2 + dim % 2 - out = np.zeros((num_pos, dim)) - out[:, 0:sentinel] = sin - out[:, sentinel:] = cos - - return jnp.array(out) - - -def rotate_every_two(tensor): - rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1) - rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,)) - return rotate_half_tensor - - -def apply_rotary_pos_emb(tensor, sincos): - sin_pos, cos_pos = sincos - sin_pos = sin_pos[:, :, None, :].repeat(2, 3) - cos_pos = cos_pos[:, :, None, :].repeat(2, 3) - return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) - - -class FlaxGPTJAttention(nn.Module): - config: GPTJConfig - dtype: jnp.dtype = jnp.float32 - causal: bool = True - is_cross_attention: bool = False - - def setup(self): - config = self.config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - - self.rotary_dim = config.rotary_dim - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) - - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - - pos_embd_dim = self.rotary_dim or self.embed_dim - self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key - # positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - position_ids, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) - - sincos = jnp.take(self.embed_positions, position_ids, axis=0) - sincos = jnp.split(sincos, 2, axis=-1) - if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] - - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] - - k_rot = apply_rotary_pos_emb(k_rot, sincos) - q_rot = apply_rotary_pos_emb(q_rot, sincos) - - key = jnp.concatenate([k_rot, k_pass], axis=-1) - query = jnp.concatenate([q_rot, q_pass], axis=-1) - else: - key = apply_rotary_pos_emb(key, sincos) - query = apply_rotary_pos_emb(query, sincos) - - query_length, key_length = query.shape[1], key.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - - dropout_rng = None - if not deterministic and self.config.attn_pdrop > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - # transform boolean mask into float mask - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - - # usual dot product attention - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attn_pdrop, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxGPTJMLP(nn.Module): - config: GPTJConfig - intermediate_size: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - - self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) - self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) - - self.act = ACT2FN[self.config.activation_function] - self.dropout = nn.Dropout(rate=self.config.resid_pdrop) - - def __call__(self, hidden_states, deterministic: bool = True): - hidden_states = self.fc_in(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc_out(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxGPTJBlock(nn.Module): - config: GPTJConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - hidden_size = self.config.hidden_size - inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype) - - self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - - feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) - # residual connection - hidden_states = attn_output + feed_forward_hidden_states + residual - - return (hidden_states,) + attn_outputs[1:] - - -class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPTJConfig - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: GPTJConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return init_variables["cache"] - - @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxGPTJBlockCollection(nn.Module): - config: GPTJConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxGPTJBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = block( - hidden_states, - attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxGPTJModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -class FlaxGPTJModule(nn.Module): - config: GPTJConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embed_dim = self.config.hidden_size - - self.wte = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.dropout = nn.Dropout(rate=self.config.embd_pdrop) - self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype) - self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.wte(input_ids.astype("i4")) - - hidden_states = self.dropout(input_embeds, deterministic=deterministic) - - outputs = self.h( - hidden_states, - attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.", - GPTJ_START_DOCSTRING, -) -class FlaxGPTJModel(FlaxGPTJPreTrainedModel): - module_class = FlaxGPTJModule - - -append_call_sample_docstring( - FlaxGPTJModel, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxGPTJForCausalLMModule(nn.Module): - config: GPTJConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.transformer( - input_ids, - attention_mask, - position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T - lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The GPTJ Model transformer with a language modeling head on top. - """, - GPTJ_START_DOCSTRING, -) -class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel): - module_class = FlaxGPTJForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since GPTJ uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxGPTJForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutput, - _CONFIG_FOR_DOC, -) - - -__all__ = ["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"] diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py deleted file mode 100644 index 0ec32258223c..000000000000 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ /dev/null @@ -1,1094 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The EleutherAI and HuggingFace Teams. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 GPT-J model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPast, - TFCausalLMOutputWithPast, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutputWithPast, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSharedEmbeddings, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import logging -from .configuration_gptj import GPTJConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B" -_CONFIG_FOR_DOC = "GPTJConfig" - - -def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor: - inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32) - sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32) - sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp) - out = tf.concat((sin, cos), axis=1) - return out - - -def rotate_every_two(x: tf.Tensor) -> tf.Tensor: - rotate_half_tensor = tf.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1) - new_shape = shape_list(rotate_half_tensor)[:-2] + [tf.math.reduce_prod(shape_list(rotate_half_tensor)[-2:])] - rotate_half_tensor = tf.reshape(rotate_half_tensor, new_shape) - return rotate_half_tensor - - -def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor: - sin_pos, cos_pos = sincos - sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3) - cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3) - return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) - - -class TFGPTJAttention(keras.layers.Layer): - def __init__(self, config: GPTJConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_attention_heads - if self.head_dim * self.num_attention_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" - f" `num_attention_heads`: {self.num_attention_heads})." - ) - self.scale_attn = self.head_dim**0.5 - self.rotary_dim = config.rotary_dim - - self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) - self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) - - self.q_proj = keras.layers.Dense( - self.embed_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="q_proj", - ) - self.k_proj = keras.layers.Dense( - self.embed_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="k_proj", - ) - self.v_proj = keras.layers.Dense( - self.embed_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="v_proj", - ) - self.out_proj = keras.layers.Dense( - self.embed_dim, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="out_proj", - ) - - self.max_positions = config.max_position_embeddings - self.lower_triangle_mask = tf.reshape( - tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8), - (1, 1, self.max_positions, self.max_positions), - ) - pos_embd_dim = self.rotary_dim or self.embed_dim - self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim) - - def get_causal_mask(self, key_length, query_length) -> tf.Tensor: - return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool) - - @staticmethod - def get_masked_bias(dtype: tf.DType) -> tf.Tensor: - return tf.cast(tf.constant(-1e9), dtype) - - def _split_heads(self, hidden_states: tf.Tensor, rotary: bool) -> tf.Tensor: - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - new_shape = shape_list(hidden_states)[:-1] + [self.num_attention_heads, self.head_dim] - hidden_states = tf.reshape(hidden_states, new_shape) - if rotary: - return hidden_states - if len(shape_list(hidden_states)) == 4: - return tf.transpose(hidden_states, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - if len(shape_list(hidden_states)) == 5: - return tf.transpose(hidden_states, (0, 1, 3, 2, 4)) # (batch, blocks, head, block_length, head_features) - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") - - def _merge_heads(self, hidden_states: tf.Tensor) -> tf.Tensor: - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - if len(shape_list(hidden_states)) == 4: - hidden_states = tf.transpose(hidden_states, (0, 2, 1, 3)) - elif len(shape_list(hidden_states)) == 5: - hidden_states = tf.transpose(hidden_states, (0, 1, 3, 2, 4)) - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") - new_shape = shape_list(hidden_states)[:-2] + [self.num_attention_heads * self.head_dim] - return tf.reshape(hidden_states, new_shape) - - def _attn( - self, - query: tf.Tensor, - key: tf.Tensor, - value: tf.Tensor, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - ) -> tuple[tf.Tensor, tf.Tensor]: - # compute causal mask from causal mask buffer - query_length, key_length = shape_list(query)[-2], shape_list(key)[-2] - causal_mask = self.get_causal_mask(key_length, query_length) - - # Keep the attention weights computation in fp32 to avoid overflow issues - query = tf.cast(query, tf.float32) - key = tf.cast(key, tf.float32) - - attn_weights = tf.matmul(query, key, transpose_b=True) - attn_weights = tf.where(causal_mask, attn_weights, self.get_masked_bias(attn_weights.dtype)) - - attn_weights = attn_weights / self.scale_attn - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = stable_softmax(attn_weights, axis=-1) - attn_weights = tf.cast(attn_weights, value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = tf.matmul(attn_weights, value) - - return attn_output, attn_weights - - def call( - self, - hidden_states: tf.Tensor, - layer_past: tuple[tf.Tensor, tf.Tensor] | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, True) - key = self._split_heads(key, True) - value = self._split_heads(value, False) - - sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype) - sincos = tf.split(sincos, 2, axis=-1) - if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] - - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] - - k_rot = apply_rotary_pos_emb(k_rot, sincos) - q_rot = apply_rotary_pos_emb(q_rot, sincos) - - key = tf.concat((k_rot, k_pass), axis=-1) - query = tf.concat((q_rot, q_pass), axis=-1) - else: - key = apply_rotary_pos_emb(key, sincos) - query = apply_rotary_pos_emb(query, sincos) - - key = tf.transpose(key, (0, 2, 1, 3)) - query = tf.transpose(query, (0, 2, 1, 3)) - - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = tf.concat((past_key, key), axis=-2) - value = tf.concat((past_value, value), axis=-2) - - if use_cache is True: - present = (key, value) - else: - present = None - - # compute self-attention: V x Softmax(QK^T) - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFGPTJMLP(keras.layers.Layer): - def __init__(self, intermediate_size: int, config: GPTJConfig, **kwargs): - super().__init__(**kwargs) - embed_dim = config.n_embd - - self.fc_in = keras.layers.Dense( - intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="fc_in" - ) - self.fc_out = keras.layers.Dense( - embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="fc_out" - ) - - self.act = get_tf_activation(config.activation_function) - self.dropout = keras.layers.Dropout(config.embd_pdrop) - self.embed_dim = config.n_embd - self.intermediate_size = intermediate_size - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.fc_in(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc_out(hidden_states) - hidden_states = self.dropout(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "fc_in", None) is not None: - with tf.name_scope(self.fc_in.name): - self.fc_in.build([None, None, self.embed_dim]) - if getattr(self, "fc_out", None) is not None: - with tf.name_scope(self.fc_out.name): - self.fc_out.build([None, None, self.intermediate_size]) - - -class TFGPTJBlock(keras.layers.Layer): - def __init__(self, config: GPTJConfig, **kwargs): - super().__init__(**kwargs) - inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd - self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") - self.attn = TFGPTJAttention(config, name="attn") - self.mlp = TFGPTJMLP(inner_dim, config, name="mlp") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - layer_past: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) # attn_outputs: attn_output, present, (attentions) - attn_output = attn_outputs[0] - outputs = attn_outputs[1:] - - feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - return outputs # hidden_states, present, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "ln_1", None) is not None: - with tf.name_scope(self.ln_1.name): - self.ln_1.build([None, None, self.config.n_embd]) - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - - -@keras_serializable -class TFGPTJMainLayer(keras.layers.Layer): - config_class = GPTJConfig - - def __init__(self, config: GPTJConfig, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - self.config = config - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.use_cache = config.use_cache - self.return_dict = config.use_return_dict - - self.num_hidden_layers = config.n_layer - self.n_embd = config.n_embd - self.n_positions = config.n_positions - self.initializer_range = config.initializer_range - - self.wte = TFSharedEmbeddings( - config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte" - ) - self.drop = keras.layers.Dropout(config.embd_pdrop) - self.h = [TFGPTJBlock(config, name=f"h_._{i}") for i in range(config.n_layer)] - self.ln_f = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f") - self.embed_dim = config.n_embd - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, value: tf.Tensor): - self.wte.weight = value - self.wte.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ) -> TFBaseModelOutputWithPast | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_length = 0 - past_key_values = [None] * len(self.h) - else: - past_length = shape_list(past_key_values[0][0])[-2] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) - - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - one_cst = tf.constant(1.0) - attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) - attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - # head_mask = tf.constant([0] * self.num_hidden_layers) - - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.wte.vocab_size) - inputs_embeds = self.wte(input_ids, mode="embedding") - - if token_type_ids is not None: - token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - token_type_embeds = self.wte(token_type_ids, mode="embedding") - else: - token_type_embeds = tf.constant(0.0) - - token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) - hidden_states = inputs_embeds + token_type_embeds - hidden_states = self.drop(hidden_states, training=training) - - output_shape = input_shape + [shape_list(hidden_states)[-1]] - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - - outputs = block( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - - hidden_states = outputs[0] - if use_cache: - presents = presents + (outputs[1],) - - if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = tf.reshape(hidden_states, output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if output_attentions: - # let the number of heads free (-1) so we can extract attention even after head pruning - attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] - all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wte", None) is not None: - with tf.name_scope(self.wte.name): - self.wte.build(None) - if getattr(self, "ln_f", None) is not None: - with tf.name_scope(self.ln_f.name): - self.ln_f.build([None, None, self.embed_dim]) - if getattr(self, "h", None) is not None: - for layer in self.h: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFGPTJPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPTJConfig - base_model_prefix = "transformer" - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias"] - - -GPTJ_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`GPTJConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -GPTJ_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of - input past key value states). Indices of input sequence tokens in the vocabulary. - - If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see - `past` output below). Can be used to speed up sequential decoding. The token ids which have their past - given to this model should not be passed as input ids as they have already been computed. - attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, input_ids_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.", - GPTJ_START_DOCSTRING, -) -class TFGPTJModel(TFGPTJPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFGPTJMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPast | tuple[tf.Tensor]: - r""" - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past`). Set to `False` during training, `True` during generation - """ - - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - The GPT-J Model transformer with a language modeling head on top. - """, - GPTJ_START_DOCSTRING, -) -class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFGPTJMainLayer(config, name="transformer") - self.lm_head = keras.layers.Dense( - config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head" - ) - self.config = config - - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids") - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - inputs = tf.expand_dims(inputs[:, -1], -1) - if token_type_ids is not None: - token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) - - position_ids = kwargs.get("position_ids") - attention_mask = kwargs.get("attention_mask") - - if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) - if past_key_values: - position_ids = tf.expand_dims(position_ids[:, -1], -1) - - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "token_type_ids": token_type_ids, - } - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithPast | tuple[tf.Tensor]: - r""" - labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size, input_ids_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = lm_logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.config.n_embd]) - - -@add_start_docstrings( - """ - The GPT-J Model transformer with a sequence classification head on top (linear layer). - - [`GPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT, GPT-2, GPT-Neo) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - GPTJ_START_DOCSTRING, -) -class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss): - _keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.transformer = TFGPTJMainLayer(config, name="transformer") - self.score = keras.layers.Dense( - self.num_labels, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="score", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutputWithPast | tuple[tf.Tensor]: - r""" - labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - if labels is not None and self.config.pad_token_id is None and input_ids.shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "score", None) is not None: - with tf.name_scope(self.score.name): - self.score.build([None, None, self.config.n_embd]) - - -@add_start_docstrings( - """ - The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like - SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - GPTJ_START_DOCSTRING, -) -class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss): - _keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.transformer = TFGPTJMainLayer(config, name="transformer") - self.qa_outputs = keras.layers.Dense( - self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`np.ndarray` or `tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFGPTJForCausalLM", - "TFGPTJForQuestionAnswering", - "TFGPTJForSequenceClassification", - "TFGPTJModel", - "TFGPTJPreTrainedModel", -] diff --git a/src/transformers/models/groupvit/modeling_tf_groupvit.py b/src/transformers/models/groupvit/modeling_tf_groupvit.py deleted file mode 100644 index 1c999dca5f48..000000000000 --- a/src/transformers/models/groupvit/modeling_tf_groupvit.py +++ /dev/null @@ -1,2141 +0,0 @@ -# coding=utf-8 -# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 GroupViT model.""" - -from __future__ import annotations - -import collections.abc -import math -from dataclasses import dataclass -from typing import Any - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_tensorflow_probability_available, - logging, - replace_return_docstrings, -) -from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig - - -logger = logging.get_logger(__name__) - -# soft dependency -if is_tensorflow_probability_available(): - try: - import tensorflow_probability as tfp - - # On the first call, check whether a compatible version of TensorFlow is installed - # TensorFlow Probability depends on a recent stable release of TensorFlow - _ = tfp.distributions.Normal(loc=0.0, scale=1.0) - except ImportError: - logger.error( - "GroupViT models are not usable since `tensorflow_probability` can't be loaded. " - "It seems you have `tensorflow_probability` installed with the wrong tensorflow version." - "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability." - ) -else: - try: - import tensorflow_probability as tfp - - # On the first call, check whether a compatible version of TensorFlow is installed - # TensorFlow Probability depends on a recent stable release of TensorFlow - _ = tfp.distributions.Normal(loc=0.0, scale=1.0) - except ImportError: - pass - -_CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# contrastive loss function, adapted from -# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html -def contrastive_loss(logits: tf.Tensor) -> tf.Tensor: - return tf.math.reduce_mean( - keras.metrics.sparse_categorical_crossentropy( - y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True - ) - ) - - -# Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->groupvit -def groupvit_loss(similarity: tf.Tensor) -> tf.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(tf.transpose(similarity)) - return (caption_loss + image_loss) / 2.0 - - -def hard_softmax(logits: tf.Tensor, dim: int) -> tf.Tensor: - y_soft = stable_softmax(logits, dim) - # Straight through. - index = tf.argmax(y_soft, dim) - y_hard = tf.one_hot( - index, - depth=shape_list(logits)[dim], - # TensorFlow expects axis to be -1 or between [0, 3). But received: -2 - # This is why the following code snippet is used. - axis=range(len(shape_list(logits)))[dim], - dtype=y_soft.dtype, - ) - ret = y_hard - tf.stop_gradient(y_soft) + y_soft - - return ret - - -def gumbel_softmax(logits: tf.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> tf.Tensor: - gumbel_dist = tfp.distributions.Gumbel(0.0, 1.0) - gumbels = gumbel_dist.sample(tf.shape(logits), dtype=logits.dtype) - - gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) - y_soft = stable_softmax(gumbels, dim) - - if hard: - # Straight through. - index = tf.argmax(y_soft, dim) - y_hard = tf.one_hot( - index, - depth=shape_list(logits)[dim], - # TensorFlow expects axis to be -1 or between [0, 3). But received: -2 - # This is why the following code snippet is used. - axis=range(len(shape_list(logits)))[dim], - dtype=y_soft.dtype, - ) - ret = y_hard - tf.stop_gradient(y_soft) + y_soft - else: - # Reparametrization trick. - ret = y_soft - return ret - - -def resize_attention_map(attentions: tf.Tensor, height: int, width: int, align_corners: bool = False) -> tf.Tensor: - """ - Args: - attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width] - height (`int`): height of the output attention map - width (`int`): width of the output attention map - align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`. - - Returns: - `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width] - """ - - scale = (height * width // attentions.shape[2]) ** 0.5 - if height > width: - feat_width = int(np.round(width / scale)) - feat_height = shape_list(attentions)[2] // feat_width - else: - feat_height = int(np.round(height / scale)) - feat_width = shape_list(attentions)[2] // feat_height - - batch_size = shape_list(attentions)[0] - groups = shape_list(attentions)[1] # number of group token - # [batch_size, groups, height x width, groups] -> [batch_size, groups, height, width] - attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width)) - attentions = tf.transpose(attentions, perm=(0, 2, 3, 1)) - if align_corners: - attentions = tf.compat.v1.image.resize( - attentions, - size=(height, width), - method="bilinear", - align_corners=align_corners, - ) - else: - attentions = tf.image.resize(attentions, size=(height, width), method="bilinear") - attentions = tf.transpose(attentions, perm=(0, 3, 1, 2)) - return attentions - - -def get_grouping_from_attentions(attentions: tuple[tf.Tensor], hw_shape: tuple[int]) -> tf.Tensor: - """ - Args: - attentions (`tuple(tf.Tensor)`: tuple of attention maps returned by `TFGroupViTVisionTransformer` - hw_shape (`tuple(int)`): height and width of the output attention map - Returns: - `tf.Tensor`: the attention map of shape [batch_size, groups, height, width] - """ - - attn_maps = [] - prev_attn_masks = None - for attn_masks in attentions: - # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups] - attn_masks = tf.transpose(attn_masks, perm=(0, 2, 1)) - if prev_attn_masks is None: - prev_attn_masks = attn_masks - else: - prev_attn_masks = tf.matmul(prev_attn_masks, attn_masks) - # [batch_size, height x width, num_groups] -> [batch_size, num_groups, height x width] -> [batch_size, num_groups, height, width] - cur_attn_map = resize_attention_map(tf.transpose(prev_attn_masks, perm=(0, 2, 1)), *hw_shape) - attn_maps.append(cur_attn_map) - - # [batch_size, num_groups, height, width] - final_grouping = attn_maps[-1] - - return tf.stop_gradient(final_grouping) - - -@dataclass -class TFGroupViTModelOutput(ModelOutput): - """ - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - segmentation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. - - - - The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is - to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the - original image size as post-processing. You should always check your logits shape and resize as needed. - - - - text_embeds (`tf.Tensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of - [`TFGroupViTTextModel`]. - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of - [`TFGroupViTVisionModel`]. - text_model_output (`TFBaseModelOutputWithPooling`): - The output of the [`TFGroupViTTextModel`]. - vision_model_output (`TFBaseModelOutputWithPooling`): - The output of the [`TFGroupViTVisionModel`]. - """ - - loss: tf.Tensor | None = None - logits_per_image: tf.Tensor | None = None - logits_per_text: tf.Tensor | None = None - segmentation_logits: tf.Tensor | None = None - text_embeds: tf.Tensor | None = None - image_embeds: tf.Tensor | None = None - text_model_output: TFBaseModelOutputWithPooling = None - vision_model_output: TFBaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class TFGroupViTCrossAttentionLayer(keras.layers.Layer): - def __init__(self, config: GroupViTVisionConfig, **kwargs): - super().__init__(**kwargs) - self.attn = TFGroupViTAttention(config, name="attn") - self.norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm2") - self.mlp = TFGroupViTMLP(config, name="mlp") - self.norm_post = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post") - self.config = config - - def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False) -> tf.Tensor: - x = query - x = x + self.attn(query, encoder_hidden_states=key)[0] - x = x + self.mlp(self.norm2(x)) - x = self.norm_post(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "norm2", None) is not None: - with tf.name_scope(self.norm2.name): - self.norm2.build([None, None, self.config.hidden_size]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "norm_post", None) is not None: - with tf.name_scope(self.norm_post.name): - self.norm_post.build([None, None, self.config.hidden_size]) - - -class TFGroupViTAssignAttention(keras.layers.Layer): - def __init__(self, config: GroupViTVisionConfig, **kwargs): - super().__init__(**kwargs) - self.scale = config.hidden_size**-0.5 - - self.q_proj = keras.layers.Dense(config.hidden_size, name="q_proj") - self.k_proj = keras.layers.Dense(config.hidden_size, name="k_proj") - self.v_proj = keras.layers.Dense(config.hidden_size, name="v_proj") - self.proj = keras.layers.Dense(config.hidden_size, name="proj") - self.assign_eps = config.assign_eps - self.config = config - - def get_attn(self, attn: tf.Tensor, gumbel: bool = True, hard: bool = True, training: bool = False) -> tf.Tensor: - if gumbel and training: - attn = gumbel_softmax(attn, dim=-2, hard=hard) - else: - if hard: - attn = hard_softmax(attn, dim=-2) - else: - attn = stable_softmax(attn, axis=-2) - - return attn - - def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False): - value = key - # [batch_size, query_length, channels] - query = self.q_proj(query) - - # [batch_size, key_length, channels] - key = self.k_proj(key) - - # [batch_size, key_length, channels] - value = self.v_proj(value) - - # [batch_size, query_length, key_length] - raw_attn = tf.matmul(query, key, transpose_b=True) * self.scale - - attn = self.get_attn(raw_attn, training=training) - soft_attn = self.get_attn(raw_attn, training=training, gumbel=False, hard=False) - - attn = attn / (tf.math.reduce_sum(attn, axis=-1, keepdims=True) + self.assign_eps) - - out = tf.matmul(attn, value) - - out = self.proj(out) - - return out, soft_attn - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.config.hidden_size]) - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build([None, None, self.config.hidden_size]) - - -class TFGroupViTTokenAssign(keras.layers.Layer): - def __init__(self, config: GroupViTVisionConfig, num_group_token: int, num_output_group: int, **kwargs): - super().__init__(**kwargs) - self.num_output_group = num_output_group - # norm on group_tokens - self.norm_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_tokens") - assign_mlp_ratio = ( - config.assign_mlp_ratio - if isinstance(config.assign_mlp_ratio, collections.abc.Iterable) - else (config.assign_mlp_ratio, config.assign_mlp_ratio) - ) - tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio] - self.mlp_inter = TFGroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group, name="mlp_inter") - self.norm_post_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post_tokens") - # norm on x - self.norm_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_x") - self.pre_assign_attn = TFGroupViTCrossAttentionLayer(config, name="pre_assign_attn") - - self.assign = TFGroupViTAssignAttention(config, name="assign") - self.norm_new_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_new_x") - self.mlp_channels = TFGroupViTMLP( - config, config.hidden_size, channels_dim, config.hidden_size, name="mlp_channels" - ) - self.config = config - - def project_group_token(self, group_tokens: tf.Tensor) -> tf.Tensor: - """ - Args: - group_tokens (tf.Tensor): group tokens, [batch_size, num_group_tokens, channels] - - Returns: - projected_group_tokens (tf.Tensor): [batch_size, num_output_groups, channels] - """ - # [B, num_output_groups, C] <- [B, num_group_tokens, C] - projected_group_tokens = self.mlp_inter(group_tokens) - projected_group_tokens = self.norm_post_tokens(projected_group_tokens) - return projected_group_tokens - - def call(self, image_tokens: tf.Tensor, group_tokens: tf.Tensor, training: bool = False): - """ - Args: - image_tokens (`tf.Tensor`): image tokens, of shape [batch_size, input_length, channels] - group_tokens (`tf.Tensor`): group tokens, [batch_size, num_group_tokens, channels] - """ - - group_tokens = self.norm_tokens(group_tokens) - image_tokens = self.norm_x(image_tokens) - # [batch_size, num_output_groups, channels] - projected_group_tokens = self.project_group_token(group_tokens) - projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens) - new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens) - new_image_tokens += projected_group_tokens - - new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens)) - - return new_image_tokens, attention - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "norm_tokens", None) is not None: - with tf.name_scope(self.norm_tokens.name): - self.norm_tokens.build([None, None, self.config.hidden_size]) - if getattr(self, "mlp_inter", None) is not None: - with tf.name_scope(self.mlp_inter.name): - self.mlp_inter.build(None) - if getattr(self, "norm_post_tokens", None) is not None: - with tf.name_scope(self.norm_post_tokens.name): - self.norm_post_tokens.build([None, None, self.config.hidden_size]) - if getattr(self, "norm_x", None) is not None: - with tf.name_scope(self.norm_x.name): - self.norm_x.build([None, None, self.config.hidden_size]) - if getattr(self, "pre_assign_attn", None) is not None: - with tf.name_scope(self.pre_assign_attn.name): - self.pre_assign_attn.build(None) - if getattr(self, "assign", None) is not None: - with tf.name_scope(self.assign.name): - self.assign.build(None) - if getattr(self, "norm_new_x", None) is not None: - with tf.name_scope(self.norm_new_x.name): - self.norm_new_x.build([None, None, self.config.hidden_size]) - if getattr(self, "mlp_channels", None) is not None: - with tf.name_scope(self.mlp_channels.name): - self.mlp_channels.build(None) - - -# Adapted from transformers.models.vit.modeling_tf_vit.TFViTPatchEmbeddings with ViT->GroupViT -class TFGroupViTPatchEmbeddings(keras.layers.Layer): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config: GroupViTConfig, **kwargs): - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels = config.num_channels - # hidden_size is a member as it will be required in the call method - self.hidden_size = config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = num_patches - self.num_channels = num_channels - self.config = config - - self.projection = keras.layers.Conv2D( - filters=self.hidden_size, - kernel_size=patch_size, - strides=patch_size, - padding="valid", - data_format="channels_last", - use_bias=True, - kernel_initializer=get_initializer(self.config.initializer_range), - bias_initializer="zeros", - name="projection", - ) - - def call( - self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False - ) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - if ( - not interpolate_pos_encoding - and tf.executing_eagerly() - and (height != self.image_size[0] or width != self.image_size[1]) - ): - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - projection = self.projection(pixel_values) - - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) - # In the TFGroupViTVisionEmbeddings the embeddings from this layer will be layer normalized - # LayerNormalization layer needs to have static last dimension (otherwise the test_keras_save_load fails with symbolic tensors) - # This is why we have used the hidden_size in the reshape method - embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, self.hidden_size)) - - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -# Adapted from transformers.vit.modeling_tf_vit.TFViTEmbeddings -class TFGroupViTVisionEmbeddings(keras.layers.Layer): - """ - Construct the position and patch embeddings. - - """ - - def __init__(self, config: GroupViTVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.patch_embeddings = TFGroupViTPatchEmbeddings(config, name="patch_embeddings") - self.dropout = keras.layers.Dropout(rate=config.dropout, name="dropout") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - self.config = config - - def build(self, input_shape=None): - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = self.add_weight( - shape=(1, num_patches, self.config.hidden_size), - initializer="zeros", - trainable=True, - name="position_embeddings", - ) - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.hidden_size]) - - def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - batch_size, num_patches, dim = shape_list(embeddings) - num_positions = shape_list(self.position_embeddings)[1] - - if num_patches == num_positions and height == width: - return self.position_embeddings - patch_pos_embed = self.position_embeddings - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - patch_pos_embed = tf.image.resize( - images=tf.reshape( - patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - ), - size=(h0, w0), - method="bicubic", - ) - patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) - return patch_pos_embed - - def call( - self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False - ) -> tf.Tensor: - _, _, height, width = shape_list(pixel_values) - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - embeddings = self.layernorm(embeddings) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - - -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->GroupViT -class TFGroupViTTextEmbeddings(keras.layers.Layer): - def __init__(self, config: GroupViTTextConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - - self.config = config - - def build(self, input_shape: tf.TensorShape = None): - with tf.name_scope("token_embedding"): - self.weight = self.add_weight( - shape=(self.config.vocab_size, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="weight", - ) - - with tf.name_scope("position_embedding"): - self.position_embedding = self.add_weight( - shape=(self.config.max_position_embeddings, self.embed_dim), - initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range), - trainable=True, - name="embeddings", - ) - - super().build(input_shape) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - if input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - position_embeds = tf.gather(params=self.position_embedding, indices=position_ids) - position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1)) - final_embeddings = inputs_embeds + position_embeds - - return final_embeddings - - -class TFGroupViTStage(keras.layers.Layer): - """This corresponds to the `GroupingLayer` class in the GroupViT implementation.""" - - def __init__( - self, - config: GroupViTVisionConfig, - depth: int, - num_prev_group_token: int, - num_group_token: int, - num_output_group: int, - **kwargs, - ): - super().__init__(**kwargs) - self.config = config - self.depth = depth - self.num_group_token = num_group_token - self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(depth)] - - if num_group_token > 0: - self.downsample = TFGroupViTTokenAssign( - config=config, - num_group_token=num_group_token, - num_output_group=num_output_group, - name="downsample", - ) - else: - self.downsample = None - - if num_prev_group_token > 0 and num_group_token > 0: - self.group_projector = [ - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="group_projector.0"), - TFGroupViTMixerMLP( - config, num_prev_group_token, config.hidden_size // 2, num_group_token, name="group_projector.1" - ), - ] - else: - self.group_projector = None - - def build(self, input_shape=None): - if self.num_group_token > 0: - self.group_token = self.add_weight( - shape=(1, self.num_group_token, self.config.hidden_size), - initializer="zeros", - trainable=True, - name="group_token", - ) - else: - self.group_token = None - - if self.built: - return - self.built = True - if getattr(self, "downsample", None) is not None: - with tf.name_scope(self.downsample.name): - self.downsample.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - if getattr(self, "group_projector", None) is not None: - with tf.name_scope(self.group_projector[0].name): - self.group_projector[0].build([None, None, self.config.hidden_size]) - with tf.name_scope(self.group_projector[1].name): - self.group_projector[1].build(None) - - @property - def with_group_token(self): - return self.group_token is not None - - def split_x(self, x: tf.Tensor) -> tf.Tensor: - if self.with_group_token: - return x[:, : -self.num_group_token], x[:, -self.num_group_token :] - else: - return x, None - - def concat_x(self, x: tf.Tensor, group_token: tf.Tensor | None = None) -> tf.Tensor: - if group_token is None: - return x - return tf.concat([x, group_token], axis=1) - - def call( - self, - hidden_states: tf.Tensor, - prev_group_token: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the grouping tensors of Grouping block. - """ - if self.with_group_token: - group_token = tf.tile(self.group_token, multiples=(shape_list(hidden_states)[0], 1, 1)) - if self.group_projector is not None: - for layer in self.group_projector: - prev_group_token = layer(prev_group_token) - group_token = group_token + prev_group_token - else: - group_token = None - - x = hidden_states - - cat_x = self.concat_x(x, group_token) - for layer in self.layers: - layer_out = layer( - cat_x, - attention_mask=None, - causal_attention_mask=None, - output_attentions=None, - ) - cat_x = layer_out[0] - - x, group_token = self.split_x(cat_x) - - attention = None - if self.downsample is not None: - x, attention = self.downsample(x, group_token) - - outputs = (x, group_token) - if output_attentions: - outputs = outputs + (attention,) - - return outputs - - -class TFGroupViTMLP(keras.layers.Layer): - def __init__( - self, - config: GroupViTVisionConfig, - hidden_size: int | None = None, - intermediate_size: int | None = None, - output_size: int | None = None, - **kwargs, - ): - super().__init__(**kwargs) - self.config = config - self.activation_fn = get_tf_activation(config.hidden_act) - hidden_size = hidden_size if hidden_size is not None else config.hidden_size - intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size - output_size = output_size if output_size is not None else hidden_size - self.fc1 = keras.layers.Dense(intermediate_size, name="fc1") - self.fc2 = keras.layers.Dense(output_size, name="fc2") - self.intermediate_size = intermediate_size - self.hidden_size = hidden_size - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.hidden_size]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.intermediate_size]) - - -class TFGroupViTMixerMLP(TFGroupViTMLP): - def call(self, x, training: bool = False): - x = super().call(hidden_states=tf.transpose(x, perm=(0, 2, 1))) - return tf.transpose(x, perm=(0, 2, 1)) - - -# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPAttention -class TFGroupViTAttention(keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GroupViTConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = self.embed_dim // self.num_attention_heads - if self.attention_head_size * self.num_attention_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_attention_heads})." - ) - - factor = config.initializer_factor - in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor - out_proj_std = (self.embed_dim**-0.5) * factor - - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.q_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj" - ) - self.k_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj" - ) - self.v_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj" - ) - - self.dropout = keras.layers.Dropout(rate=config.attention_dropout) - - self.out_proj = keras.layers.Dense( - units=self.embed_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj" - ) - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - causal_attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - encoder_hidden_states: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor]: - """Input shape: Batch x Time x Channel""" - - batch_size = shape_list(hidden_states)[0] - is_cross_attention = encoder_hidden_states is not None - - mixed_query_layer = self.q_proj(inputs=hidden_states) - if is_cross_attention: - mixed_key_layer = self.k_proj(inputs=encoder_hidden_states) - mixed_value_layer = self.v_proj(inputs=encoder_hidden_states) - else: - mixed_key_layer = self.k_proj(inputs=hidden_states) - mixed_value_layer = self.v_proj(inputs=hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - # Apply the causal attention mask (precomputed for all layers in TFCLIPModel call() function) - attention_scores = tf.add(attention_scores, causal_attention_mask) - - if attention_mask is not None: - # Apply the attention mask (precomputed for all layers in TFCLIPModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - _attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=_attention_probs) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, embed_dim) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.embed_dim)) - - attention_output = self.out_proj(attention_output) - # In TFBert, attention weights are returned after dropout. - # However, in CLIP, they are returned before dropout. - outputs = (attention_output, _attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPEncoderLayer with CLIP->GroupViT -class TFGroupViTEncoderLayer(keras.layers.Layer): - def __init__(self, config: GroupViTConfig, **kwargs): - super().__init__(**kwargs) - - self.embed_dim = config.hidden_size - self.self_attn = TFGroupViTAttention(config, name="self_attn") - self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") - self.mlp = TFGroupViTMLP(config, name="mlp") - self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - causal_attention_mask (`tf.Tensor`): causal attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`): - Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned - tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(inputs=hidden_states) - attention_outputs = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = attention_outputs[0] - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(inputs=hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, self.embed_dim]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, self.embed_dim]) - - -# Adapted from transformers.models.clip.modeling_tf_clip.TFGroupViTTextEncoder -class TFGroupViTTextEncoder(keras.layers.Layer): - def __init__(self, config: GroupViTTextConfig, **kwargs): - super().__init__(**kwargs) - - self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFGroupViTVisionEncoder(keras.layers.Layer): - def __init__(self, config: GroupViTVisionConfig, **kwargs) -> None: - super().__init__(**kwargs) - - self.stages = [ - TFGroupViTStage( - config=config, - depth=config.depths[i], - num_group_token=config.num_group_tokens[i], - num_output_group=config.num_output_groups[i], - num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0, - name=f"stages_._{i}", - ) - for i in range(len(config.depths)) - ] - - def call( - self, - hidden_states: tf.Tensor, - output_hidden_states: bool, - output_attentions: bool, - return_dict: bool, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - all_hidden_states = () if output_hidden_states else None - all_groupings = () if output_attentions else None - - group_tokens = None - - for stage in self.stages: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = stage(hidden_states, group_tokens, output_attentions) - - hidden_states = layer_outputs[0] - group_tokens = layer_outputs[1] - - if output_attentions and layer_outputs[2] is not None: - all_groupings = all_groupings + (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "stages", None) is not None: - for layer in self.stages: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder -class TFGroupViTTextTransformer(keras.layers.Layer): - def __init__(self, config: GroupViTTextConfig, **kwargs): - super().__init__(**kwargs) - - self.embeddings = TFGroupViTTextEmbeddings(config, name="embeddings") - self.encoder = TFGroupViTTextEncoder(config, name="encoder") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - - # For `pooled_output` computation - self.eos_token_id = config.eos_token_id - self.embed_dim = config.hidden_size - - def call( - self, - input_ids: TFModelInputType, - attention_mask: tf.Tensor, - position_ids: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - input_shape = shape_list(input_ids) - - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - batch_size, seq_length = input_shape - # CLIP's text model uses causal mask, prepare it here. - # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype) - - # check attention mask and invert - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.final_layer_norm(inputs=sequence_output) - - if self.eos_token_id == 2: - # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. - # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added - # ------------------------------------------------------------ - # text_embeds.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - pooled_output = tf.gather_nd( - params=sequence_output, - indices=tf.stack( - values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1 - ), - ) - else: - # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible) - pooled_output = tf.gather_nd( - params=sequence_output, - indices=tf.stack( - values=( - tf.range(input_shape[0], dtype=tf.int64), - tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1), - ), - axis=1, - ), - ) - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32): - # It is possible with an unspecified sequence length for seq_length to be - # a runtime value, which is unsupported by tf.constant. Per the TensorFlow - # docs, tf.fill can handle runtime dynamic shapes: - # https://www.tensorflow.org/api_docs/python/tf/fill - diag = tf.cast(tf.fill((seq_length,), 0.0), dtype) - - # set an additive 2D attention mask with all places being masked - to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype) - - # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked) - # TIP: think the 2D matrix as the space of (query_seq, key_seq) - to_mask = tf.linalg.band_part(to_mask, 0, -1) - # to_mask = tf.linalg.band_part(to_mask, -1, 0) - to_mask = tf.linalg.set_diag(to_mask, diagonal=diag) - - return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPVisionTransformer -class TFGroupViTVisionTransformer(keras.layers.Layer): - def __init__(self, config: GroupViTVisionConfig, **kwargs): - super().__init__(**kwargs) - - self.embeddings = TFGroupViTVisionEmbeddings(config, name="embeddings") - self.encoder = TFGroupViTVisionEncoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - self.embed_dim = config.hidden_size - - def call( - self, - pixel_values: TFModelInputType, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> tuple | TFBaseModelOutputWithPooling: - embedding_output = self.embeddings(pixel_values) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - - # normalize the last hidden state - last_hidden_state = self.layernorm(last_hidden_state) - pooled_output = tf.math.reduce_mean(last_hidden_state, axis=1) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.embed_dim]) - - -@keras_serializable -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextMainLayer with CLIP->GroupViT -class TFGroupViTTextMainLayer(keras.layers.Layer): - config_class = GroupViTTextConfig - - def __init__(self, config: GroupViTTextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.text_model = TFGroupViTTextTransformer(config, name="text_model") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.text_model.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.text_model.embeddings.weight = value - self.text_model.embeddings.vocab_size = shape_list(value)[0] - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - text_model_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return text_model_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "text_model", None) is not None: - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - - -@keras_serializable -# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPVisionMainLayer with CLIP->GroupViT -class TFGroupViTVisionMainLayer(keras.layers.Layer): - config_class = GroupViTVisionConfig - - def __init__(self, config: GroupViTVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.vision_model = TFGroupViTVisionTransformer(config, name="vision_model") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.vision_model.embeddings - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - vision_model_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return vision_model_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - - -@keras_serializable -# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPMainLayer -class TFGroupViTMainLayer(keras.layers.Layer): - config_class = GroupViTConfig - - def __init__(self, config: GroupViTConfig, **kwargs): - super().__init__(**kwargs) - - if not isinstance(config.text_config, GroupViTTextConfig): - raise TypeError( - "config.text_config is expected to be of type GroupViTTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, GroupViTVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type GroupViTVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - self.config = config - - text_config = config.text_config - vision_config = config.vision_config - - self.projection_dim = config.projection_dim - self.projection_intermediate_dim = config.projection_intermediate_dim - self.text_embed_dim = text_config.hidden_size - self.vision_embed_dim = vision_config.hidden_size - - self.text_model = TFGroupViTTextTransformer(text_config, name="text_model") - self.vision_model = TFGroupViTVisionTransformer(vision_config, name="vision_model") - - self.visual_projection = [ - keras.layers.Dense(self.projection_intermediate_dim, name="visual_projection.0"), - keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.9, epsilon=1e-5), - keras.layers.ReLU(name="visual_projection.2"), - keras.layers.Dense(self.projection_dim, name="visual_projection.3"), - ] - self.text_projection = [ - keras.layers.Dense(self.projection_intermediate_dim, name="text_projection.0"), - keras.layers.BatchNormalization(name="text_projection.1", momentum=0.9, epsilon=1e-5), - keras.layers.ReLU(name="text_projection.2"), - keras.layers.Dense(self.projection_dim, name="text_projection.3"), - ] - - def build(self, input_shape=None): - self.logit_scale = self.add_weight( - shape=(1,), - initializer=keras.initializers.Constant(self.config.logit_scale_init_value), - trainable=True, - name="logit_scale", - ) - - if self.built: - return - self.built = True - if getattr(self, "text_model", None) is not None: - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "visual_projection", None) is not None: - with tf.name_scope(self.visual_projection[0].name): - self.visual_projection[0].build([None, None, None, self.vision_embed_dim]) - with tf.name_scope(self.visual_projection[1].name): - self.visual_projection[1].build((None, self.projection_intermediate_dim)) - with tf.name_scope(self.visual_projection[3].name): - self.visual_projection[3].build([None, None, None, self.projection_intermediate_dim]) - if getattr(self, "text_projection", None) is not None: - with tf.name_scope(self.text_projection[0].name): - self.text_projection[0].build([None, None, None, self.text_embed_dim]) - with tf.name_scope(self.text_projection[1].name): - self.text_projection[1].build((None, self.projection_intermediate_dim)) - with tf.name_scope(self.text_projection[3].name): - self.text_projection[3].build([None, None, None, self.projection_intermediate_dim]) - - @unpack_inputs - def get_text_features( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - if input_ids is None: - raise ValueError("You have to specify either input_ids") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = text_outputs[1] - for layer in self.text_projection: - pooled_output = layer(pooled_output) - - text_features = pooled_output - return text_features - - @unpack_inputs - def get_image_features( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = vision_outputs[1] - for layer in self.visual_projection: - pooled_output = layer(pooled_output) - - image_features = pooled_output - return image_features - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - pixel_values: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_segmentation: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFGroupViTModelOutput | tuple[tf.Tensor]: - if input_ids is None: - raise ValueError("You have to specify either input_ids") - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - input_shape = shape_list(input_ids) - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - if output_segmentation: - output_attentions = True - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[1] - for layer in self.visual_projection: - image_embeds = layer(image_embeds) - - text_embeds = text_outputs[1] - for layer in self.text_projection: - text_embeds = layer(text_embeds) - - # normalized features - image_embeds = image_embeds / tf.norm(image_embeds, axis=-1, keepdims=True) - text_embeds = text_embeds / tf.norm(text_embeds, axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = tf.math.exp(self.logit_scale) - logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale - logits_per_image = tf.transpose(logits_per_text) - - seg_logits = None - if output_segmentation: - # grouped features - # [batch_size_image, num_group, hidden_size] - image_group_embeds = vision_outputs[0] - # [batch_size_image*num_group, hidden_size] - image_group_embeds = tf.reshape(image_group_embeds, shape=(-1, shape_list(image_group_embeds)[-1])) - for layer in self.visual_projection: - image_group_embeds = layer(image_group_embeds) - if output_hidden_states: - attentions = vision_outputs[3] - else: - attentions = vision_outputs[2] - # [batch_size_image, num_group, height, width] - grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:]) - - # normalized features - image_group_embeds = image_group_embeds / tf.norm( - tensor=image_group_embeds, ord="euclidean", axis=-1, keepdims=True - ) - # [batch_size_image x num_group, batch_size_text] - logits_per_image_group = tf.matmul(image_group_embeds, text_embeds, transpose_b=True) * logit_scale - # [batch_size_image, batch_size_text, num_group] - logits_per_image_group = tf.reshape( - logits_per_image_group, shape=(image_embeds.shape[0], -1, text_embeds.shape[0]) - ) - logits_per_image_group = tf.transpose(logits_per_image_group, perm=(0, 2, 1)) - - # [batch_size_image, batch_size_text, height x width] - flatten_grouping = tf.reshape(grouping, shape=(shape_list(grouping)[0], shape_list(grouping)[1], -1)) - - # [batch_size_image, batch_size_text, height, width] - seg_logits = tf.matmul(logits_per_image_group, flatten_grouping) * logit_scale - seg_logits = tf.reshape( - seg_logits, shape=(seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]) - ) - - loss = None - if return_loss: - loss = groupvit_loss(logits_per_text)[None, ...] - - if not return_dict: - if seg_logits is not None: - output = ( - logits_per_image, - logits_per_text, - seg_logits, - text_embeds, - image_embeds, - text_outputs, - vision_outputs, - ) - else: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - - return TFGroupViTModelOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - segmentation_logits=seg_logits, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -class TFGroupViTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GroupViTConfig - base_model_prefix = "groupvit" - - -GROUPVIT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TF 2.0 models accepts two formats as inputs: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional arguments. - - This second option is useful when using [`keras.Model.fit`] method which currently requires having all the - tensors in the first argument of the model call function: `model(inputs)`. - - If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the - first positional argument : - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - - - Args: - config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -GROUPVIT_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -GROUPVIT_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -GROUPVIT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`CLIPImageProcessor.__call__`] for details. - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -class TFGroupViTTextModel(TFGroupViTPreTrainedModel): - config_class = GroupViTTextConfig - main_input_name = "input_ids" - - def __init__(self, config: GroupViTTextConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.groupvit = TFGroupViTTextMainLayer(config, name="groupvit") - - @unpack_inputs - @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTTextConfig) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import CLIPTokenizer, TFGroupViTTextModel - - >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc") - >>> model = TFGroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - - outputs = self.groupvit( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "groupvit", None) is not None: - with tf.name_scope(self.groupvit.name): - self.groupvit.build(None) - - -class TFGroupViTVisionModel(TFGroupViTPreTrainedModel): - config_class = GroupViTVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: GroupViTVisionConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.groupvit = TFGroupViTVisionMainLayer(config, name="groupvit") - - @unpack_inputs - @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTVisionConfig) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFGroupViTVisionModel - - >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") - >>> model = TFGroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled CLS states - ```""" - - outputs = self.groupvit( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "groupvit", None) is not None: - with tf.name_scope(self.groupvit.name): - self.groupvit.build(None) - - -@add_start_docstrings(GROUPVIT_START_DOCSTRING) -class TFGroupViTModel(TFGroupViTPreTrainedModel): - config_class = GroupViTConfig - - def __init__(self, config: GroupViTConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.groupvit = TFGroupViTMainLayer(config, name="groupvit") - - @unpack_inputs - @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def get_text_features( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - r""" - Returns: - text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of [`TFGroupViTTextModel`]. - - Examples: - - ```python - >>> from transformers import CLIPTokenizer, TFGroupViTModel - - >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") - >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") - >>> text_features = model.get_text_features(**inputs) - ```""" - - text_features = self.groupvit.get_text_features( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return text_features - - @unpack_inputs - @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING) - def get_image_features( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tf.Tensor: - r""" - Returns: - image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying - the projection layer to the pooled output of [`TFGroupViTVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFGroupViTModel - - >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") - >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="tf") - - >>> image_features = model.get_image_features(**inputs) - ```""" - - image_features = self.groupvit.get_image_features( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return image_features - - @unpack_inputs - @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFGroupViTModelOutput, config_class=GroupViTConfig) - def call( - self, - input_ids: TFModelInputType | None = None, - pixel_values: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_segmentation: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFGroupViTModelOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, TFGroupViTModel - >>> import tensorflow as tf - - >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") - >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True - ... ) - - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = tf.math.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ```""" - - outputs = self.groupvit( - input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids, - return_loss=return_loss, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_segmentation=output_segmentation, - return_dict=return_dict, - training=training, - ) - - return outputs - - def serving_output(self, output: TFGroupViTModelOutput) -> TFGroupViTModelOutput: - # TODO: As is this currently fails with saved_model=True, because - # TensorFlow cannot trace through nested dataclasses. Reference: - # https://github.com/huggingface/transformers/pull/16886 - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "groupvit", None) is not None: - with tf.name_scope(self.groupvit.name): - self.groupvit.build(None) - - -__all__ = ["TFGroupViTModel", "TFGroupViTPreTrainedModel", "TFGroupViTTextModel", "TFGroupViTVisionModel"] diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py deleted file mode 100644 index 45c05ff30737..000000000000 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ /dev/null @@ -1,1671 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow Hubert model.""" - -from __future__ import annotations - -import warnings -from typing import Any - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput -from ...modeling_tf_utils import ( - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_hubert import HubertConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "HubertConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement -def _sample_without_replacement(distribution, num_samples): - """ - Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see - https://github.com/tensorflow/tensorflow/issues/9260 for more info - """ - z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1)) - _, indices = tf.nn.top_k(distribution + z, num_samples) - return indices - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._scatter_values_on_batch_indices -def _scatter_values_on_batch_indices(values, batch_indices, output_shape): - """ - Scatter function as in PyTorch with indices in format (batch_dim, indices) - """ - indices_shape = shape_list(batch_indices) - # broadcast batch dim to indices_shape - broad_casted_batch_dims = tf.reshape( - tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1] - ) - # transform batch_indices to pair_indices - pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0)) - # scatter values to pair indices - return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: tuple[int, int], - mask_prob: float, - mask_length: int, - min_masks: int = 0, -) -> tf.Tensor: - """ - Computes random mask spans for a given shape - - Args: - shape: the shape for which to compute masks. - should be of size 2 where first element is batch size and 2nd is timesteps - attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements - mask_prob: - probability for each token to be chosen as start of the span to be masked. this will be multiplied by - number of timesteps divided by length of mask span to mask approximately this percentage of all elements. - however due to overlaps, the actual number will be smaller (unless no_overlap is True) - mask_length: size of the mask - min_masks: minimum number of masked spans - - Adapted from [fairseq's - data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376). - """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - tf.debugging.assert_less( - mask_length, - sequence_length, - message=( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" - f" `sequence_length`: {sequence_length}`" - ), - ) - - # compute number of masked spans in batch - num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,)) - num_masked_spans = tf.maximum(num_masked_spans, min_masks) - num_masked_spans = tf.cast(num_masked_spans, tf.int32) - - # make sure num masked indices <= sequence_length - num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans) - num_masked_spans = tf.squeeze(num_masked_spans) - - # SpecAugment mask to fill - spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32) - - # uniform distribution to sample from, make sure that offset samples are < sequence_length - uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1))) - - # get random indices to mask - spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans) - - # expand masked indices to masked spans - spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1) - spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length)) - spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length)) - - offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :] - offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1)) - offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length)) - - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # scatter indices to mask - spec_aug_mask = _scatter_values_on_batch_indices( - tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask) - ) - - return spec_aug_mask - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNorm with Wav2Vec2->Hubert -class TFHubertGroupNorm(keras.layers.Layer): - """ - From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization - """ - - def __init__( - self, - groups: int = 32, - axis: int = -1, - epsilon: float = 1e-3, - center: bool = True, - scale: bool = True, - beta_initializer: keras.initializers.Initializer = "zeros", - gamma_initializer: keras.initializers.Initializer = "ones", - beta_regularizer: keras.regularizers.Regularizer = None, - gamma_regularizer: keras.regularizers.Regularizer = None, - beta_constraint: keras.constraints.Constraint = None, - gamma_constraint: keras.constraints.Constraint = None, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - self.groups = groups - self.axis = axis - self.epsilon = epsilon - self.center = center - self.scale = scale - self.beta_initializer = keras.initializers.get(beta_initializer) - self.gamma_initializer = keras.initializers.get(gamma_initializer) - self.beta_regularizer = keras.regularizers.get(beta_regularizer) - self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) - self.beta_constraint = keras.constraints.get(beta_constraint) - self.gamma_constraint = keras.constraints.get(gamma_constraint) - self._check_axis() - - def build(self, input_shape): - self._check_if_input_shape_is_none(input_shape) - self._set_number_of_groups_for_instance_norm(input_shape) - self._check_size_of_dimensions(input_shape) - self._create_input_spec(input_shape) - - self._add_gamma_weight(input_shape) - self._add_beta_weight(input_shape) - self.built = True - super().build(input_shape) - - def call(self, inputs): - input_shape = keras.backend.int_shape(inputs) - tensor_input_shape = tf.shape(inputs) - - reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape) - - normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) - - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - outputs = tf.reshape(normalized_inputs, tensor_input_shape) - else: - outputs = normalized_inputs - - return outputs - - def get_config(self): - config = { - "groups": self.groups, - "axis": self.axis, - "epsilon": self.epsilon, - "center": self.center, - "scale": self.scale, - "beta_initializer": keras.initializers.serialize(self.beta_initializer), - "gamma_initializer": keras.initializers.serialize(self.gamma_initializer), - "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer), - "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer), - "beta_constraint": keras.constraints.serialize(self.beta_constraint), - "gamma_constraint": keras.constraints.serialize(self.gamma_constraint), - } - base_config = super().get_config() - return {**base_config, **config} - - def compute_output_shape(self, input_shape): - return input_shape - - def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): - group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - group_shape[self.axis] = input_shape[self.axis] // self.groups - group_shape.insert(self.axis, self.groups) - group_shape = tf.stack(group_shape) - reshaped_inputs = tf.reshape(inputs, group_shape) - return reshaped_inputs, group_shape - else: - return inputs, group_shape - - def _apply_normalization(self, reshaped_inputs, input_shape): - group_shape = keras.backend.int_shape(reshaped_inputs) - group_reduction_axes = list(range(1, len(group_shape))) - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - axis = -2 if self.axis == -1 else self.axis - 1 - else: - axis = -1 if self.axis == -1 else self.axis - 1 - group_reduction_axes.pop(axis) - - mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True) - - gamma, beta = self._get_reshaped_weights(input_shape) - normalized_inputs = tf.nn.batch_normalization( - reshaped_inputs, - mean=mean, - variance=variance, - scale=gamma, - offset=beta, - variance_epsilon=self.epsilon, - ) - return normalized_inputs - - def _get_reshaped_weights(self, input_shape): - broadcast_shape = self._create_broadcast_shape(input_shape) - gamma = None - beta = None - if self.scale: - gamma = tf.reshape(self.gamma, broadcast_shape) - - if self.center: - beta = tf.reshape(self.beta, broadcast_shape) - return gamma, beta - - def _check_if_input_shape_is_none(self, input_shape): - dim = input_shape[self.axis] - if dim is None: - raise ValueError( - "Axis " - + str(self.axis) - + " of input tensor should have a defined dimension but the layer received an input with shape " - + str(input_shape) - + "." - ) - - def _set_number_of_groups_for_instance_norm(self, input_shape): - dim = input_shape[self.axis] - - if self.groups == -1: - self.groups = dim - - def _check_size_of_dimensions(self, input_shape): - dim = input_shape[self.axis] - if dim < self.groups: - raise ValueError( - "Number of groups (" - + str(self.groups) - + ") cannot be more than the number of channels (" - + str(dim) - + ")." - ) - - if dim % self.groups != 0: - raise ValueError( - "Number of groups (" - + str(self.groups) - + ") must be a multiple of the number of channels (" - + str(dim) - + ")." - ) - - def _check_axis(self): - if self.axis == 0: - raise ValueError( - "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead" - ) - - def _create_input_spec(self, input_shape): - dim = input_shape[self.axis] - self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim}) - - def _add_gamma_weight(self, input_shape): - dim = input_shape[self.axis] - shape = (dim,) - - if self.scale: - self.gamma = self.add_weight( - shape=shape, - name="gamma", - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - constraint=self.gamma_constraint, - ) - else: - self.gamma = None - - def _add_beta_weight(self, input_shape): - dim = input_shape[self.axis] - shape = (dim,) - - if self.center: - self.beta = self.add_weight( - shape=shape, - name="beta", - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - constraint=self.beta_constraint, - ) - else: - self.beta = None - - def _create_broadcast_shape(self, input_shape): - broadcast_shape = [1] * len(input_shape) - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - broadcast_shape[self.axis] = input_shape[self.axis] // self.groups - broadcast_shape.insert(self.axis, self.groups) - else: - broadcast_shape[self.axis] = self.groups - return broadcast_shape - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2WeightNormConv1D with Wav2Vec2->Hubert -class TFHubertWeightNormConv1D(keras.layers.Conv1D): - """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm""" - - def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs): - super().__init__( - filters=filters, - kernel_size=kernel_size, - groups=groups, - padding="valid", - use_bias=True, - bias_initializer="he_normal", - **kwargs, - ) - self.explicit_padding = explicit_padding - self.filter_axis = 2 - self.kernel_norm_axes = tf.constant([0, 1]) - - def _init_norm(self): - """Set the norm of the weight vector.""" - kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes)) - self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis]) - - def _normalize_kernel(self): - """Generate normalized weights.""" - kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g) - self.kernel = tf.transpose(kernel) - - def build(self, input_shape): - if not self.built: - super().build(input_shape) - - self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True) - self.weight_v = self.kernel - - self.weight_g = self.add_weight( - name="weight_g", - shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1), - initializer="ones", - dtype=self.weight_v.dtype, - trainable=True, - ) - self._init_norm() - self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True) - - def call(self, inputs): - # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent. - # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls - # a functional 1d convolution with normalized weights that it generates (but does not store!) - self._normalize_kernel() - - padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0))) - output = super().call(padded_inputs) - - return output - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert -class TFHubertNoLayerNormConvLayer(keras.layers.Layer): - def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.activation = get_tf_activation(config.feat_extract_activation) - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert -class TFHubertLayerNormConvLayer(keras.layers.Layer): - def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps) - self.activation = get_tf_activation(config.feat_extract_activation) - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.out_conv_dim]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert -class TFHubertGroupNormConvLayer(keras.layers.Layer): - def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.activation = get_tf_activation(config.feat_extract_activation) - self.layer_norm = TFHubertGroupNorm(groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm") - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.out_conv_dim]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert -class TFHubertPositionalConvEmbedding(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.conv = TFHubertWeightNormConv1D( - filters=config.hidden_size, - kernel_size=config.num_conv_pos_embeddings, - groups=config.num_conv_pos_embedding_groups, - explicit_padding=config.num_conv_pos_embeddings // 2, - name="conv", - ) - self.padding = TFHubertSamePadLayer(config.num_conv_pos_embeddings) - self.activation = get_tf_activation(config.feat_extract_activation) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.padding(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2SamePadLayer with Wav2Vec2->Hubert -class TFHubertSamePadLayer(keras.layers.Layer): - def __init__(self, num_conv_pos_embeddings, **kwargs): - super().__init__(**kwargs) - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 - - def call(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, : -self.num_pad_remove, :] - return hidden_states - - -class TFHubertFeatureEncoder(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs: Any) -> None: - super().__init__(**kwargs) - - if config.feat_extract_norm == "group": - conv_layers = [TFHubertGroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [ - TFHubertNoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i + 1}") - for i in range(config.num_feat_extract_layers - 1) - ] - elif config.feat_extract_norm == "layer": - conv_layers = [ - TFHubertLayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}") - for i in range(config.num_feat_extract_layers) - ] - else: - raise ValueError( - f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" - ) - self.conv_layers = conv_layers - - def call(self, input_values): - hidden_states = tf.expand_dims(input_values, -1) - for conv_layer in self.conv_layers: - hidden_states = conv_layer(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - for conv_layer in self.conv_layers: - with tf.name_scope(conv_layer.name): - conv_layer.build(None) - - -class TFHubertFeatureExtractor(TFHubertFeatureEncoder): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -class TFHubertFeatureProjection(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.projection = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="projection", - ) - self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout) - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.projection(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.conv_dim[-1]]) - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, self.config.conv_dim[-1]]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFHubert -class TFHubertAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2FeedForward with Wav2Vec2->Hubert -class TFHubertFeedForward(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - - self.intermediate_dropout = keras.layers.Dropout(config.activation_dropout) - - self.intermediate_dense = keras.layers.Dense( - units=config.intermediate_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="intermediate_dense", - ) - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - - self.output_dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="output_dense", - ) - self.output_dropout = keras.layers.Dropout(config.hidden_dropout) - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.intermediate_dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.intermediate_dropout(hidden_states, training=training) - - hidden_states = self.output_dense(hidden_states) - hidden_states = self.output_dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "intermediate_dense", None) is not None: - with tf.name_scope(self.intermediate_dense.name): - self.intermediate_dense.build([None, None, self.config.hidden_size]) - if getattr(self, "output_dense", None) is not None: - with tf.name_scope(self.output_dense.name): - self.output_dense.build([None, None, self.config.intermediate_size]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayer with Wav2Vec2->Hubert -class TFHubertEncoderLayer(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - self.attention = TFHubertAttention( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - name="attention", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.feed_forward = TFHubertFeedForward(config, name="feed_forward") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attn_residual = hidden_states - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, training=training - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = attn_residual + hidden_states - - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states + self.feed_forward(hidden_states) - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "feed_forward", None) is not None: - with tf.name_scope(self.feed_forward.name): - self.feed_forward.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert -class TFHubertEncoderLayerStableLayerNorm(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - self.attention = TFHubertAttention( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - name="attention", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.feed_forward = TFHubertFeedForward(config, name="feed_forward") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, training=training - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "feed_forward", None) is not None: - with tf.name_scope(self.feed_forward.name): - self.feed_forward.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2Encoder with Wav2Vec2->Hubert -class TFHubertEncoder(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer = [TFHubertEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - position_embeddings = self.pos_conv_embed(hidden_states) - hidden_states = hidden_states + position_embeddings - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = np.random.uniform(0, 1) - if training and (dropout_probability < self.config.layerdrop): # skip the layer - continue - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "pos_conv_embed", None) is not None: - with tf.name_scope(self.pos_conv_embed.name): - self.pos_conv_embed.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert -class TFHubertEncoderStableLayerNorm(keras.layers.Layer): - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer = [ - TFHubertEncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) - ] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - position_embeddings = self.pos_conv_embed(hidden_states) - hidden_states = hidden_states + position_embeddings - hidden_states = self.dropout(hidden_states, training=training) - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = np.random.uniform(0, 1) - if training and (dropout_probability < self.config.layerdrop): # skip the layer - continue - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "pos_conv_embed", None) is not None: - with tf.name_scope(self.pos_conv_embed.name): - self.pos_conv_embed.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFHubertMainLayer(keras.layers.Layer): - config_class = HubertConfig - - def __init__(self, config: HubertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.feature_extractor = TFHubertFeatureEncoder(config, name="feature_extractor") - self.feature_projection = TFHubertFeatureProjection(config, name="feature_projection") - - if config.do_stable_layer_norm: - self.encoder = TFHubertEncoderStableLayerNorm(config, name="encoder") - else: - self.encoder = TFHubertEncoder(config, name="encoder") - - def build(self, input_shape=None): - self.masked_spec_embed = self.add_weight( - shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed" - ) - - if self.built: - return - self.built = True - if getattr(self, "feature_extractor", None) is not None: - with tf.name_scope(self.feature_extractor.name): - self.feature_extractor.build(None) - if getattr(self, "feature_projection", None) is not None: - with tf.name_scope(self.feature_projection.name): - self.feature_projection.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): - """ - Computes the output length of the convolutional layers - """ - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - return input_lengths - - def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None): - """ - Masks extracted features along time axis and/or along feature axis according to - [SpecAugment](https://huggingface.co/papers/1904.08779). - """ - batch_size, sequence_length, hidden_size = shape_list(hidden_states) - - # `config.apply_spec_augment` can set masking to False - if not getattr(self.config, "apply_spec_augment", True): - return hidden_states - - if mask_time_indices is not None: - # apply SpecAugment along time axis with given mask_time_indices - hidden_states = tf.where( - tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), - self.masked_spec_embed[tf.newaxis, tf.newaxis, :], - hidden_states, - ) - - elif self.config.mask_time_prob > 0: - # generate indices & apply SpecAugment along time axis - mask_time_indices = _compute_mask_indices( - (batch_size, sequence_length), - mask_prob=self.config.mask_time_prob, - mask_length=self.config.mask_time_length, - min_masks=2, - ) - hidden_states = tf.where( - tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), - self.masked_spec_embed[tf.newaxis, tf.newaxis, :], - hidden_states, - ) - - # apply SpecAugment along feature axis - if self.config.mask_feature_prob > 0: - mask_feature_indices = _compute_mask_indices( - (batch_size, hidden_size), - mask_prob=self.config.mask_feature_prob, - mask_length=self.config.mask_feature_length, - ) - hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0) - - return hidden_states - - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: tf.Tensor | None = None, - output_hidden_states: tf.Tensor | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs: Any, - ): - hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training) - - if attention_mask is not None: - # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1)) - - attention_mask = tf.sequence_mask( - output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype - ) - - hidden_states = self.feature_projection(hidden_states, training=training) - - mask_time_indices = kwargs.get("mask_time_indices") - if training: - hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = encoder_outputs[0] - - if not return_dict: - return (hidden_states,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class TFHubertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = HubertConfig - base_model_prefix = "hubert" - main_input_name = "input_values" - - @property - def input_signature(self): - return { - "input_values": tf.TensorSpec((None, 16000), tf.float32, name="input_values"), - "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), - "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), - } - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - logger.warning( - f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish " - "to train/fine-tune this model, you need a GPU or a TPU" - ) - - -HUBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_values` only and nothing else: `model(input_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_values": input_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`HubertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -HUBERT_INPUTS_DOCSTRING = r""" - Args: - input_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_values` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_values` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare TFHubert Model transformer outputting raw hidden-states without any specific head on top.", - HUBERT_START_DOCSTRING, -) -class TFHubertModel(TFHubertPreTrainedModel): - def __init__(self, config: HubertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - self.hubert = TFHubertMainLayer(config, name="hubert") - - @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, TFHubertModel - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") - >>> model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1 - >>> hidden_states = model(input_values).last_hidden_state - ```""" - - output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states - output_attentions = output_attentions if output_attentions else self.config.output_attentions - return_dict = return_dict if return_dict else self.config.return_dict - - outputs = self.hubert( - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "hubert", None) is not None: - with tf.name_scope(self.hubert.name): - self.hubert.build(None) - - -@add_start_docstrings( - """TFHubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", - HUBERT_START_DOCSTRING, -) -class TFHubertForCTC(TFHubertPreTrainedModel): - def __init__(self, config: HubertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.hubert = TFHubertMainLayer(config, name="hubert") - self.dropout = keras.layers.Dropout(config.final_dropout) - self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head") - self.output_hidden_size = ( - config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size - ) - - def freeze_feature_extractor(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameters will - not be updated during training. - """ - warnings.warn( - "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " - "Please use the equivalent `freeze_feature_encoder` method instead.", - FutureWarning, - ) - self.freeze_feature_encoder() - - def freeze_feature_encoder(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameter will - not be updated during training. - """ - self.hubert.feature_extractor.trainable = False - - @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - labels: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFCausalLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_values` docstring) Tokens with indices set to `-100` are ignored (masked), - the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoProcessor, TFHubertForCTC - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") - >>> model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1 - >>> logits = model(input_values).logits - >>> predicted_ids = tf.argmax(logits, axis=-1) - - >>> transcription = processor.decode(predicted_ids[0]) - - >>> # compute loss - >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST" - - >>> # Pass the transcription as text to encode labels - >>> labels = processor(text=transcription, return_tensors="tf").input_values - - >>> loss = model(input_values, labels=labels).loss - ```""" - if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size: - raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") - - outputs = self.hubert( - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, training=training) - - logits = self.lm_head(hidden_states) - - if labels is not None: - attention_mask = ( - attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32) - ) - input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1)) - - # assuming that padded tokens are filled with -100 - # when not being attended to - labels_mask = tf.cast(labels >= 0, tf.int32) - target_lengths = tf.reduce_sum(labels_mask, axis=-1) - - loss = tf.nn.ctc_loss( - logits=logits, - labels=labels, - logit_length=input_lengths, - label_length=target_lengths, - blank_index=self.config.pad_token_id, - logits_time_major=False, - ) - - if self.config.ctc_loss_reduction == "sum": - loss = tf.reduce_sum(loss) - loss = tf.reshape(loss, (1,)) - if self.config.ctc_loss_reduction == "mean": - loss = tf.reduce_mean(loss) - loss = tf.reshape(loss, (1,)) - else: - loss = None - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "hubert", None) is not None: - with tf.name_scope(self.hubert.name): - self.hubert.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.output_hidden_size]) - - -__all__ = ["TFHubertForCTC", "TFHubertModel", "TFHubertPreTrainedModel"] diff --git a/src/transformers/models/idefics/modeling_tf_idefics.py b/src/transformers/models/idefics/modeling_tf_idefics.py deleted file mode 100644 index 0e8e75be28f8..000000000000 --- a/src/transformers/models/idefics/modeling_tf_idefics.py +++ /dev/null @@ -1,1778 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Idefics model.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import tensorflow as tf - -from ... import TFPreTrainedModel -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ModelOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - keras_serializable, - shape_list, - unpack_inputs, -) -from ...tf_utils import invert_attention_mask, scaled_dot_product_attention -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_idefics import IdeficsConfig -from .perceiver_tf import TFIdeficsPerceiverResampler -from .vision_tf import TFIdeficsVisionTransformer - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "IdeficsConfig" - - -@dataclass -class TFIdeficsBaseModelOutputWithPast(ModelOutput): - """ - Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(tf.Tensor)`, *optional*): - Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver - """ - - last_hidden_state: tf.Tensor | None = None - past_key_values: tuple[tuple[tf.Tensor]] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - image_hidden_states: tuple[tf.Tensor] | None = None - - -@dataclass -class TFIdeficsCausalLMOutputWithPast(ModelOutput): - """ - Base class for Idefics causal language model (or autoregressive) outputs. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(tf.Tensor)`, *optional*): - Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - image_hidden_states: tuple[tf.Tensor] | None = None - - -def expand_inputs_for_generation( - input_ids, - expand_size=1, - is_encoder_decoder=False, - attention_mask=None, - encoder_outputs=None, - **model_kwargs, -): - expanded_return_idx = tf.reshape(tf.repeat(tf.range(tf.shape(input_ids)[0]), expand_size), [-1]) - input_ids = tf.gather(input_ids, expanded_return_idx) - model_kwargs["pixel_values"] = model_kwargs.get("pixel_values") - model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings") - model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings") - model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask") - - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx) - - if attention_mask is not None: - model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx) - - if model_kwargs["image_attention_mask"] is not None: - model_kwargs["image_attention_mask"] = tf.gather(model_kwargs["image_attention_mask"], expanded_return_idx) - - if model_kwargs["pixel_values"] is not None: - model_kwargs["pixel_values"] = tf.gather(model_kwargs["pixel_values"], expanded_return_idx) - - elif model_kwargs["image_encoder_embeddings"] is not None: - model_kwargs["image_encoder_embeddings"] = tf.gather( - model_kwargs["image_encoder_embeddings"], expanded_return_idx - ) - - elif model_kwargs["perceiver_embeddings"] is not None: - model_kwargs["perceiver_embeddings"] = tf.gather(model_kwargs["perceiver_embeddings"], expanded_return_idx) - - return input_ids, model_kwargs - - -def update_model_kwargs_for_generation(outputs, model_kwargs): - # must have this key set to at least None - if "past_key_values" in outputs: - model_kwargs["past_key_values"] = outputs.past_key_values - else: - model_kwargs["past_key_values"] = None - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1) - - # update attention masks - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = tf.concat( - [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1 - ) - if "image_attention_mask" in model_kwargs: - image_attention_mask = model_kwargs["image_attention_mask"] - last_mask = image_attention_mask[:, -1:, ...] - model_kwargs["image_attention_mask"] = last_mask - - # Get the precomputed image_hidden_states - model_kwargs["image_hidden_states"] = outputs.image_hidden_states - - return model_kwargs - - -def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids") - # only last token for inputs_ids if past is defined in kwargs - if past_key_values is not None: - input_ids = input_ids[:, -1:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1:] - - attention_mask = kwargs.get("attention_mask") - position_ids = kwargs.get("position_ids") - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) - if past_key_values is not None: - position_ids = position_ids[:, -1:] - - pixel_values = kwargs.get("pixel_values") - image_encoder_embeddings = kwargs.get("image_encoder_embeddings") - perceiver_embeddings = kwargs.get("perceiver_embeddings") - image_attention_mask = kwargs.get("image_attention_mask") - interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - "pixel_values": pixel_values, - "image_encoder_embeddings": image_encoder_embeddings, - "perceiver_embeddings": perceiver_embeddings, - "image_attention_mask": image_attention_mask, - "interpolate_pos_encoding": interpolate_pos_encoding, - } - - -def freeze_model(model, module_exceptions=[]): - mapping = { - "LayerNorm": tf.keras.layers.LayerNormalization, - "Dense": tf.keras.layers.Dense, - "Embedding": tf.keras.layers.Embedding, - } - module_exceptions_mapped = [mapping[m] for m in module_exceptions] - if not hasattr(model, "layers"): - model.trainable = False # It is just a layer - return model - for layer in model.layers: - if module_exceptions and any(isinstance(layer, t) for t in module_exceptions_mapped): - layer.trainable = True # Explicitly setting it to true to avoid any mistakes - else: - layer.trainable = False - return model - - -class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding): - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the - regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, - then it will create `num_additional_embeddings` additional parameters that are always trained. If - `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Embedding`. - """ - - def __init__( - self, - num_embeddings, - num_additional_embeddings, - embedding_dim, - partially_freeze: bool | None = False, - dtype=None, - **kwargs, - ) -> None: - """ - Args: - num_embeddings (`int`): - Size of the dictionary of embeddings - num_additional_embeddings (`int`): - Number of additional embeddings. Only useful when you `partially_freeze=True`. - embedding_dim (`int`): - The size of each embedding vector - partially_freeze: (`bool`, *optional*, defaults to `False`): - If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. - - Note: there are a lot of other parameters to initialize a standard `tf.keras.layers.Embedding` such as `mask_zero`, - `input_length` or `embeddings_initializer`. We are not supporting these. - """ - super().__init__( - input_dim=num_embeddings, - output_dim=embedding_dim, - dtype=dtype, - **kwargs, - ) - self.num_embeddings = num_embeddings - self.num_additional_embeddings = num_additional_embeddings - self.partially_freeze = partially_freeze - - if partially_freeze: - self.trainable = False - - if self.num_additional_embeddings > 0: - self.additional_embedding = tf.keras.layers.Embedding( - input_dim=self.num_additional_embeddings, - output_dim=embedding_dim, - dtype=dtype, - name="additional_embedding", - ) - - def call(self, input_ids): - """ - we have 2 embeddings, with different indices - one pretrained self.weight and another - self.additional_embedding.weight that is being trained. - - in order to make a lookup of the input ids, we: - 1. find out the indices of the entries belonging to the 2nd embedding - 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd - embedding starts from 0 and not num_embeddings - 3. perform the 2nd embedding lookup - 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index - 5. perform the 1st embedding lookup - 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup - - note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but - then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - - i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are - usually relatively short it's probably not faster or if faster not by much - but might be a good idea to - measure. - - """ - if self.num_additional_embeddings == 0: - return super().call(input_ids) - - # Clone so that we don't modify the original input_ids later on - input_ids = tf.identity(input_ids) - additional_vocab_indices = tf.where(input_ids >= self.num_embeddings) - input_ids_additional_vocab = tf.gather_nd(input_ids, additional_vocab_indices) - additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) - - # for successful lookup replace input_ids with 0, the results of these will be discarded anyway - input_ids = tf.tensor_scatter_nd_update( - input_ids, - additional_vocab_indices, - # tensor filled with 0, having the same length as additional_vocab_indices - tf.zeros(tf.shape(additional_vocab_indices)[0], dtype=input_ids.dtype), - ) - full_vector = super().call(input_ids) - - # overwrite the records with high indices - full_vector = tf.tensor_scatter_nd_update(full_vector, additional_vocab_indices, additional_embeddings) - - return full_vector - - def extra_repr(self) -> str: - return f"num_embeddings={self.num_embeddings}, num_additional_embeddings={self.num_additional_embeddings}, embedding_dim={self.output_dim}, partially_freeze={self.partially_freeze}" - - -class TFIdeficsDecoupledLinear(tf.keras.layers.Layer): - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the - regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, - then it will create `out_additional_features * in_features` additional parameters that are always trained. If - `out_additional_features=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Dense`. - """ - - def __init__( - self, - in_features: int, - out_features: int, - out_additional_features: int = 0, - bias: bool = True, - partially_freeze: bool = True, - **kwargs, - ) -> None: - """ - out_additional_features: int. Number of additional trainable dimensions. Only makes sense when - `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra - parameters (if any) will be trainable. If False, default to the regular behavior of tf.keras.layers.Dense. - """ - super().__init__(**kwargs) - self.out_additional_features = out_additional_features - self.partially_freeze = partially_freeze - - self.in_features = in_features - self.out_features = out_features - self.use_bias = bias - - if out_additional_features > 0: - self.additional_fc = tf.keras.layers.Dense( - units=out_additional_features, use_bias=bias, name="additional_fc" - ) - - def call(self, inputs: tf.Tensor) -> tf.Tensor: - output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True) - if self.bias is not None: - output = tf.nn.bias_add(output, self.bias) - - if self.out_additional_features > 0: - additional_features = self.additional_fc(inputs) - output = tf.concat([output, additional_features], axis=-1) - - return output - - def get_config(self): - config = super().get_config() - config.update( - { - "in_features": self.in_features, - "out_features": self.out_features, - "out_additional_features": self.out_additional_features, - "bias": self.bias is not None, - "partially_freeze": self.partially_freeze, - } - ) - return config - - def extra_repr(self) -> str: - """Overwriting `nn.Linear.extra_repr` to include new parameters.""" - return f"in_features={self.in_features}, out_features={self.out_features}, out_additional_features={self.out_additional_features}, bias={self.bias is not None}, partially_freeze={self.partially_freeze}" - - @classmethod - def from_config(cls, config): - return cls(**config) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - self.weight = self.add_weight( - shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight" - ) - if self.use_bias: - self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias") - else: - self.bias = None - if getattr(self, "additional_fc", None) is not None: - with tf.name_scope(self.additional_fc.name): - self.additional_fc.build(self.in_features) - - -def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0): - """ - Make causal mask used for bi-directional self-attention, supporting both static and dynamic shapes. - """ - bsz, tgt_len = input_ids_shape - - # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask) - mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min) - mask_cond = tf.range(tgt_len) - mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) - - if bsz is None: - # When batch size is dynamic, expand and tile - # so we can compile a functional model - mask = tf.expand_dims(mask, 0) - mask = tf.expand_dims(mask, 0) # shape: (1, 1, tgt_len, tgt_len + past_key_values_length) - mask = tf.tile(mask, [bsz, 1, 1, 1]) - else: - # When batch size is static, directly use broadcast_to - mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) - - return mask - - -def _expand_mask(mask, dtype, tgt_len=None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = shape_list(mask) - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1) - expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len]) - - inverted_mask = 1.0 - tf.cast(expanded_mask, dtype) - - return tf.where( - tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask - ) - - -class TFIdeficsRMSNorm(tf.keras.layers.Layer): - def __init__(self, hidden_size, eps=1e-6, **kwargs): - """ - TFIdeficsRMSNorm is equivalent to T5LayerNorm - """ - super().__init__(**kwargs) - self.hidden_size = hidden_size - self.variance_epsilon = eps - - def build(self, input_shape): - if self.built: - return - self.built = True - self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones") - - super().build(input_shape) - - def call(self, hidden_states): - variance = tf.math.reduce_mean(tf.math.square(tf.cast(hidden_states, tf.float32)), axis=-1, keepdims=True) - hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [tf.float16, tf.bfloat16]: - hidden_states = tf.cast(hidden_states, self.weight.dtype) - - return self.weight * hidden_states - - -class TFIdeficsEmbedding(tf.keras.layers.Layer): - def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs): - super().__init__(**kwargs) - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.inv_freq = tf.constant( - 1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) - ) - - def _compute_cos_sin(self, seq_len): - t = tf.range(seq_len, dtype=self.inv_freq.dtype) - freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication - emb = tf.concat((freqs, freqs), axis=-1) - - return tf.cos(emb), tf.sin(emb) - - def call(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len is None: - seq_len = shape_list(x)[2] - return self._compute_cos_sin(seq_len=seq_len) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return tf.concat((-x2, x1), axis=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - cos = tf.gather(cos, position_ids) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] - sin = tf.gather(sin, position_ids) - cos = tf.expand_dims(cos, 1) - sin = tf.expand_dims(sin, 1) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class TFIdeficsMLP(tf.keras.layers.Layer): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - **kwargs, - ): - super().__init__(**kwargs) - self.gate_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj") - self.down_proj = tf.keras.layers.Dense(hidden_size, use_bias=False, name="down_proj") - self.up_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj") - self.act_fn = get_tf_activation(hidden_act) - self.intermediate_size = intermediate_size - self.hidden_size = hidden_size - - def call(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "gate_proj", None) is not None: - with tf.name_scope(self.gate_proj.name): - self.gate_proj.build(self.hidden_size) - if getattr(self, "down_proj", None) is not None: - with tf.name_scope(self.down_proj.name): - self.down_proj.build(self.intermediate_size) - if getattr(self, "up_proj", None) is not None: - with tf.name_scope(self.up_proj.name): - self.up_proj.build(self.hidden_size) - - -class TFIdeficsAttention(tf.keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout: float = 0.0, - is_cross_attention: bool = False, - config: IdeficsConfig = None, - qk_layer_norms: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.dropout = dropout - self.config = config - self.is_causal = True - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." - ) - - self.is_cross_attention = is_cross_attention - - self.q_proj = tf.keras.layers.Dense( - num_heads * self.head_dim, - use_bias=False, - name="q_proj", - ) - self.k_proj = tf.keras.layers.Dense( - num_heads * self.head_dim, - use_bias=False, - name="k_proj", - ) - self.v_proj = tf.keras.layers.Dense( - num_heads * self.head_dim, - use_bias=False, - name="v_proj", - ) - self.o_proj = tf.keras.layers.Dense( - hidden_size, - use_bias=False, - name="o_proj", - ) - self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb") - - self.qk_layer_norms = qk_layer_norms - if self.qk_layer_norms: - self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm") - self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> tuple[tf.Tensor, tf.Tensor | None, tuple[tf.Tensor] | None]: - # if key_value_states are provided this layer is used as a cross-attention layer - is_cross_attention = self.is_cross_attention or key_value_states is not None - - bsz, q_len, _ = shape_list(hidden_states) - - query_states = self._shape(self.q_proj(hidden_states), q_len, bsz) - if not is_cross_attention: - key_states = self._shape(self.k_proj(hidden_states), q_len, bsz) - value_states = self._shape(self.v_proj(hidden_states), q_len, bsz) - else: - _, kv_len, _ = shape_list(key_value_states) # Note that, in this case, `kv_len` == `kv_seq_len` - key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz) - value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz) - - kv_seq_len = shape_list(key_states)[-2] - if past_key_value is not None: - kv_seq_len += shape_list(past_key_value[0])[-2] - if not is_cross_attention: - # Below is to allow symbolic tensors compilation - if tf.is_tensor(kv_seq_len): - seq_len = tf.reduce_max(kv_seq_len, q_len) - else: - seq_len = max(kv_seq_len, q_len) - cos, sin = self.rotary_emb(value_states, seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - - past_key_value = (key_states, value_states) if use_cache else None - - if self.qk_layer_norms: - query_states = self.q_layer_norm(query_states) - key_states = self.k_layer_norm(key_states) - - tf.debugging.assert_equal( - tf.shape(attention_mask), - [bsz, 1, q_len, kv_seq_len], - message=f"Attention weights should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}", - ) - - attn_output = scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - tf.debugging.assert_equal( - tf.shape(attn_output), - [bsz, self.num_heads, q_len, self.head_dim], - message=f"Attention weights should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}", - ) - - attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size)) - - attn_output = self.o_proj(attn_output) - - attn_weights = None - if output_attentions: - logger.warning_once( - "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" - ) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.is_cross_attention: - kv_input_dim = ( - self.hidden_size - if not hasattr(self.config.vision_config, "embed_dim") - else self.config.vision_config.embed_dim - ) - else: - kv_input_dim = self.hidden_size - if getattr(self, "o_proj", None) is not None: - with tf.name_scope(self.o_proj.name): - self.o_proj.build(self.num_heads * self.head_dim) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build(self.hidden_size) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build(kv_input_dim) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build(kv_input_dim) - if getattr(self, "rotary_emb", None) is not None: - with tf.name_scope(self.rotary_emb.name): - self.rotary_emb.build(None) - - -class TFIdeficsDecoderLayer(tf.keras.layers.Layer): - def __init__(self, config: IdeficsConfig, **kwargs): - super().__init__(**kwargs) - self.hidden_size = config.hidden_size - self.self_attn = TFIdeficsAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.dropout, - config=config, - name="self_attn", - ) - self.mlp = TFIdeficsMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - name="mlp", - ) - self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") - self.post_attention_layernorm = TFIdeficsRMSNorm( - config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" - ) - self.dropout = config.dropout - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - training=False, - ) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor] | None]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "input_layernorm", None) is not None: - with tf.name_scope(self.input_layernorm.name): - self.input_layernorm.build(None) - if getattr(self, "post_attention_layernorm", None) is not None: - with tf.name_scope(self.post_attention_layernorm.name): - self.post_attention_layernorm.build(None) - - -class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer): - def __init__(self, config: IdeficsConfig, **kwargs): - super().__init__(**kwargs) - self.hidden_size = config.hidden_size - self.cross_attn = TFIdeficsAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - is_cross_attention=True, - dropout=config.dropout, - config=config, - qk_layer_norms=config.qk_layer_norms, - name="cross_attn", - ) - self.mlp = TFIdeficsMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - name="mlp", - ) - self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") - self.post_attention_layernorm = TFIdeficsRMSNorm( - config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" - ) - self.config = config.dropout - - self.act_cross_attn = tf.keras.activations.tanh - self.act_dense = tf.keras.activations.tanh - - self.alpha_initializer = config.alpha_initializer - self.alpha_type = config.alpha_type - self.alphas_initializer_range = config.alphas_initializer_range - - def build(self, input_shape): - if self.built: - return - self.built = True - if self.alpha_initializer == "zeros": - if self.alpha_type == "vector": - self.alpha_cross_attn = self.add_weight( - shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_cross_attn" - ) - self.alpha_dense = self.add_weight( - shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_dense" - ) - elif self.alpha_type == "float": - self.alpha_cross_attn = self.add_weight( - shape=(1,), initializer="zeros", trainable=True, name="alpha_cross_attn" - ) - self.alpha_dense = self.add_weight(shape=(1,), initializer="zeros", trainable=True, name="alpha_dense") - else: - raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") - - elif self.alpha_initializer == "ones": - if self.alpha_type == "vector": - self.alpha_cross_attn = self.add_weight( - shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_cross_attn" - ) - self.alpha_dense = self.add_weight( - shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_dense" - ) - elif self.alpha_type == "float": - self.alpha_cross_attn = self.add_weight( - shape=(1,), initializer="ones", trainable=True, name="alpha_cross_attn" - ) - self.alpha_dense = self.add_weight(shape=(1,), initializer="ones", trainable=True, name="alpha_dense") - else: - raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") - - elif self.alpha_initializer in {"normal", "gaussian", "random"}: - if self.alpha_type == "vector": - self.alpha_cross_attn = self.add_weight( - shape=(1, 1, self.hidden_size), - initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), - trainable=True, - name="alpha_cross_attn", - ) - self.alpha_dense = self.add_weight( - shape=(1, 1, self.hidden_size), - initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), - trainable=True, - name="alpha_dense", - ) - elif self.alpha_type == "float": - self.alpha_cross_attn = self.add_weight( - shape=(1,), - initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), - trainable=True, - name="alpha_type", - ) - self.alpha_dense = self.add_weight( - shape=(1,), - initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range), - trainable=True, - name="alpha_dense", - ) - else: - raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})") - - else: - raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!") - - if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): - raise ValueError("Alpha parameters not initialized correctly!") - with tf.name_scope(self.cross_attn.name): - self.cross_attn.build(None) - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - with tf.name_scope(self.input_layernorm.name): - self.input_layernorm.build(None) - with tf.name_scope(self.post_attention_layernorm.name): - self.post_attention_layernorm.build(None) - super().build(input_shape) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - image_hidden_states: tf.Tensor | None = None, - image_attention_mask: tf.Tensor | None = None, - cross_attention_gate: tf.Tensor | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - past_key_value: tuple[tf.Tensor] | None = None, - ) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor] | None]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states - no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored - """ - if image_hidden_states is None: - raise ValueError( - "`image_hidden_states` is required for Idefics cross attention module which are visual features to be" - " conditioned on." - ) - - if cross_attention_gate is None: - raise ValueError( - "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images." - ) - - if past_key_value is not None: - raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.cross_attn( - hidden_states=hidden_states, - key_value_states=image_hidden_states, - attention_mask=image_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = tf.nn.dropout(hidden_states, rate=self.config) - mask = tf.cast(cross_attention_gate == 0, dtype=hidden_states.dtype) - # Expand dimensions of mask to match hidden_states - mask = tf.expand_dims(mask, -1) - hidden_states = tf.where( - tf.broadcast_to(mask, tf.shape(hidden_states)) == 1, tf.zeros_like(hidden_states), hidden_states - ) - # when there are no images the model is used in pure language mode - # gate = 0 if no_images else 1 - hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = tf.nn.dropout(hidden_states, rate=self.config) - hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a TensorFlow [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) subclass. - Use it as a regular TensorFlow Layer and refer to the TensorFlow documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`IdeficsConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class TFIdeficsPreTrainedModel(TFPreTrainedModel): - config_class = IdeficsConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["TFIdeficsDecoderLayer", "TFIdeficsGatedCrossAttentionLayer"] - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -@keras_serializable -class TFIdeficsMainLayer(tf.keras.layers.Layer): - """ - Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`] - - Args: - config: IdeficsConfig - """ - - config_class = IdeficsConfig - - def __init__(self, config: IdeficsConfig, add_pooling_year: bool = True, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = TFIdeficsDecoupledEmbedding( - num_embeddings=config.vocab_size, - num_additional_embeddings=config.additional_vocab_size, - embedding_dim=config.hidden_size, - partially_freeze=config.freeze_text_layers, - name="embed_tokens", - ) - - self.image_size = config.vision_config.image_size - self.vision_config = config.vision_config - self.vision_model = TFIdeficsVisionTransformer(config.vision_config, name="vision_model") - - # Perceiver Resampler - if config.use_resampler: - perceiver_config = config.perceiver_config - self.perceiver_resampler = TFIdeficsPerceiverResampler( - config, - config.vision_config.embed_dim, - perceiver_config.resampler_depth, - perceiver_config.resampler_n_heads, - perceiver_config.resampler_head_dim, - perceiver_config.resampler_n_latents, - name="perceiver_resampler", - ) - - self.decoder_layers = [ - TFIdeficsDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) - ] - - self.cross_layer_interval = config.cross_layer_interval - num_cross_layers = config.num_hidden_layers // self.cross_layer_interval - self.gated_cross_attn_layers = [ - TFIdeficsGatedCrossAttentionLayer(config, name=f"gated_cross_attn_layers.{i}") - for i in range(num_cross_layers) - ] - self.gradient_checkpointing = False - - self.norm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm") - - self.gradient_checkpointing = False - self.freeze_relevant_params(config) - - def freeze_relevant_params(self, config=None): - if config is None: - config = self.config - - if config.freeze_text_layers: - self.freeze_text_layers(config.freeze_text_module_exceptions) - - if config.freeze_vision_layers: - freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions) - - def freeze_text_layers(self, module_exceptions=[]): - for module in [self.decoder_layers, self.norm]: - freeze_model(module, module_exceptions=module_exceptions) - - def freeze_vision_layers(self, module_exceptions=[]): - freeze_model(self.vision_model, module_exceptions=module_exceptions) - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - # if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @unpack_inputs - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - image_encoder_embeddings: tf.Tensor | None = None, - perceiver_embeddings: tf.Tensor | None = None, - image_attention_mask: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = False, - return_dict: bool | None = None, - training: bool | None = None, - ) -> TFIdeficsBaseModelOutputWithPast | tuple[tf.Tensor]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = shape_list(input_ids) - elif inputs_embeds is not None: - batch_size, seq_length, _ = shape_list(inputs_embeds) - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = shape_list(past_key_values[0][0])[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int32), axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) - elif position_ids is None: - position_ids = tf.range(past_key_values_length, seq_length + past_key_values_length, dtype=tf.int32) - position_ids = tf.expand_dims(position_ids, 0) - - no_images = False - if ( - sum((int(pixel_values is None), int(image_encoder_embeddings is None), int(perceiver_embeddings is None))) - != 2 - ): - raise ValueError( - "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None." - ) - - elif pixel_values is not None: - no_images = tf.reduce_sum(tf.cast(pixel_values, dtype=tf.int32)) == 0 - pixel_values = tf.cast(pixel_values, dtype=self.dtype) # fp16 compatibility - # Below hack is because when cross-loading pytorch weights, there is an - # initial forward pass with dummy input and code below is here to handle that - if len(pixel_values.shape) == 4: - batch_size = shape_list(pixel_values)[0] - num_images = shape_list(pixel_values)[0] - # pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[1:]]) - elif len(pixel_values.shape) == 5: - batch_size, num_images = shape_list(pixel_values)[:2] - pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[2:]]) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding - ).last_hidden_state - - elif image_encoder_embeddings is not None: - batch_size, num_images, image_seq_len, image_hidden_size = shape_list(image_encoder_embeddings) - image_hidden_states = tf.cast(image_encoder_embeddings, dtype=self.dtype) - image_hidden_states = tf.reshape( - image_hidden_states, (batch_size * num_images, image_seq_len, image_hidden_size) - ) - - if self.config.use_resampler: - if perceiver_embeddings is None: - perceiver_embeddings = self.perceiver_resampler(image_hidden_states) - image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)[1:3] - else: - batch_size, num_images, image_seq_len, image_hidden_size = shape_list(perceiver_embeddings) - image_hidden_states = perceiver_embeddings - elif perceiver_embeddings is None: - image_seq_len, image_hidden_size = shape_list(image_hidden_states)[1:3] - else: - raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True") - - image_hidden_states = tf.reshape( - image_hidden_states, (batch_size, num_images * image_seq_len, image_hidden_size) - ) - # # Hack to use the model in full language modeling mode - # image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) - - # this is to account for the dummy inputs - if pixel_values is not None and len(pixel_values.shape) == 4 and image_attention_mask is None: - image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32) - - text_seq_len = shape_list(image_attention_mask)[1] - image_attention_mask = tf.expand_dims(image_attention_mask, -1) - image_attention_mask = tf.repeat(image_attention_mask, repeats=image_seq_len) - image_attention_mask = tf.reshape(image_attention_mask, (batch_size, text_seq_len, num_images * image_seq_len)) - - if image_hidden_states is not None: - image_batch_size, image_sequence_length, _ = shape_list(image_hidden_states) - image_hidden_shape = (image_batch_size, image_sequence_length) - if image_attention_mask is None: - image_attention_mask = tf.ones(image_hidden_shape, dtype=tf.int32) - image_attention_mask = invert_attention_mask(image_attention_mask) - else: - image_attention_mask = None - - cross_attention_gate = tf.squeeze( - tf.cast(tf.reduce_any(image_attention_mask == 0, axis=-1), dtype=self.dtype), axis=1 - ) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.decoder_layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - def vblock( - main_block, - hidden_states, - attention_mask, - position_ids, - past_key_value, - image_hidden_states, - image_attention_mask, - cross_attention_gate, - output_attentions, - use_cache, - layer_idx, - cross_layer_interval, - gated_cross_attn_layers, - ): - # TODO(ls): Add cross attention values to respective lists - if layer_idx % cross_layer_interval == 0: - xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] - outputs = xblock( - hidden_states, - attention_mask=attention_mask, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - cross_attention_gate=cross_attention_gate, - output_attentions=output_attentions, - use_cache=use_cache, - past_key_value=None, # not implemented - ) - hidden_states = outputs[0] - - layer_outputs = main_block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - return layer_outputs - - if self.gradient_checkpointing and training: - past_key_value = None - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - layer_outputs = tf.recompute_grad( - vblock, - decoder_layer, - hidden_states, - attention_mask, - position_ids, - past_key_value, - image_hidden_states, - image_attention_mask, - output_attentions, - use_cache, - no_images, - idx, - self.cross_layer_interval, - self.gated_cross_attn_layers, - ) - else: - layer_outputs = vblock( - decoder_layer, - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - cross_attention_gate=cross_attention_gate, - output_attentions=output_attentions, - use_cache=use_cache, - layer_idx=idx, - cross_layer_interval=self.cross_layer_interval, - gated_cross_attn_layers=self.gated_cross_attn_layers, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - image_hidden_states = tf.reshape( - image_hidden_states, (batch_size, num_images, image_seq_len, image_hidden_size) - ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states] - if v is not None - ) - return TFIdeficsBaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - image_hidden_states=image_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "vision_model", None) is not None: - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build(None) - if getattr(self, "perceiver_resampler", None) is not None: - with tf.name_scope(self.perceiver_resampler.name): - self.perceiver_resampler.build(None) - if getattr(self, "decoder_layers", None) is not None: - for layer in self.decoder_layers: - with tf.name_scope(layer.name): - layer.build(None) - if getattr(self, "gated_cross_attn_layers", None) is not None: - for layer in self.gated_cross_attn_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFIdeficsModel(TFIdeficsPreTrainedModel): - def __init__(self, config: IdeficsConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFIdeficsMainLayer(config, name="model") - - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - image_encoder_embeddings: tf.Tensor | None = None, - perceiver_embeddings: tf.Tensor | None = None, - image_attention_mask: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = False, - return_dict: bool | None = None, - training: bool | None = None, - ) -> TFIdeficsBaseModelOutputWithPast | tuple[tf.Tensor]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_encoder_embeddings=image_encoder_embeddings, - perceiver_embeddings=perceiver_embeddings, - image_attention_mask=image_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] - config_class = IdeficsConfig - - def __init__(self, config, vision_model=None, **kwargs): - super().__init__(config, **kwargs) - self.model = TFIdeficsMainLayer(config, name="model") - self.lm_head = TFIdeficsDecoupledLinear( - config.hidden_size, - config.vocab_size, - config.additional_vocab_size, - bias=False, - partially_freeze=config.freeze_lm_head, - name="lm_head", - ) - - def tie_weights(self): - """ - Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of - IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. - """ - output_embeddings = self.get_output_embeddings() - input_embeddings = self.get_input_embeddings() - - if getattr(self.config, "tie_word_embeddings", True): - output_embeddings.weight = input_embeddings.weight - if input_embeddings.num_additional_embeddings > 0: - assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings - output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight - - if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): - output_embeddings.out_features = input_embeddings.num_embeddings - if hasattr(output_embeddings, "out_additional_features") and hasattr( - input_embeddings, "num_additional_embeddings" - ): - output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings - - @unpack_inputs - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFIdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - past_key_values: list[tf.Tensor] | None = None, - inputs_embeds: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - image_encoder_embeddings: tf.Tensor | None = None, - perceiver_embeddings: tf.Tensor | None = None, - image_attention_mask: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = False, - return_dict: bool | None = None, - training=False, - ) -> TFIdeficsCausalLMOutputWithPast | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >> from transformers import AutoTokenizer, TFIdeficsForVisionText2Text - - >> model = TFIdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b") - >> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b") - - >> prompt = "Hey, are you consciours? Can you talk to me?" - >> inputs = tokenizer(prompt, return_tensors="tf") - - >> # Generate - >> generate_ids = model.generate(inputs.input_ids, max_length=30) - >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_encoder_embeddings=image_encoder_embeddings, - perceiver_embeddings=perceiver_embeddings, - image_attention_mask=image_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - training=training, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] - shift_logits = logits[..., :-1, :][shift_attention_mask != 0] - shift_labels = labels[..., 1:][shift_attention_mask != 0] - else: - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - # Flatten the tokens - loss = self.hf_compute_loss( - labels=tf.reshape(shift_labels, [-1]), logits=tf.reshape(shift_logits, [-1, shift_logits.shape[-1]]) - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return TFIdeficsCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - image_hidden_states = kwargs.pop("image_hidden_states", None) - if image_hidden_states is not None: - if self.config.use_resampler: - kwargs["perceiver_embeddings"] = image_hidden_states - else: - kwargs["image_encoder_embeddings"] = image_hidden_states - kwargs["pixel_values"] = None - inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) - unwanted_kwargs = ["token_type_ids"] - for kwarg in unwanted_kwargs: - inputs.pop(kwarg, None) - return inputs - - @staticmethod - def _expand_inputs_for_generation( - *args, - **model_kwargs, - ): - return expand_inputs_for_generation(*args, **model_kwargs) - - @staticmethod - def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder): - return update_model_kwargs_for_generation(outputs, model_kwargs) - - @staticmethod - def _reorder_cache(past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),) - return reordered_past - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -__all__ = ["TFIdeficsForVisionText2Text", "TFIdeficsModel", "TFIdeficsPreTrainedModel"] diff --git a/src/transformers/models/idefics/perceiver_tf.py b/src/transformers/models/idefics/perceiver_tf.py deleted file mode 100644 index a4de96b68e78..000000000000 --- a/src/transformers/models/idefics/perceiver_tf.py +++ /dev/null @@ -1,195 +0,0 @@ -# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License. -# -# MIT License -# -# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - - -""" - -Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially -time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note -that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to -prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that -to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore. - -References: - - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model - - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch - -""" - -from typing import Optional - -import tensorflow as tf - -from ...modeling_tf_utils import shape_list -from .configuration_idefics import IdeficsConfig - - -class TFIdeficsPerceiverResampler(tf.keras.layers.Layer): - def __init__( - self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs - ) -> None: - """ - Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or - MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then - returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed - to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler. - Could be e.g., VIT embed_dim, ResNet pool dim, and so on. - - Args: - config (`IdeficsConfig`): config object - embed_dim (`int`): The size of each embedding vector - depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). - n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention). - head_dim (`int`): Dimensionality of each head projection in the Transformer block. - n_latents (`int`): - Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). - - """ - super().__init__(**kwargs) - self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents - self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver - - self.intermediate_dim = ( - self.embed_dim * 4 - if not hasattr(config.vision_config, "embed_dim") - else config.vision_config.embed_dim * 4 - ) - # Create Transformer Blocks - self.blocks = [] - for i in range(depth): - self.blocks.append( - [ - TFIdeficsPerceiverAttention( - self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0" - ), - TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"), - ] - ) - - self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def build(self, input_shape): - # Create Latents for Perceiver - self.latents = self.add_weight( - shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents" - ) - super().build(input_shape) - - def call(self, context: tf.Tensor) -> tf.Tensor: - """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" - # tf.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) - latents = tf.expand_dims(self.latents, axis=0) - latents = tf.tile(latents, [tf.shape(context)[0], 1, 1]) - # Feed through Perceiver Attention blocks... - for attn, ff in self.blocks: - latents = attn(context, latents) + latents - latents = ff(latents) + latents - return self.layer_norm(latents) - - -class TFIdeficsPerceiverAttention(tf.keras.layers.Layer): - def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None: - """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" - super().__init__(**kwargs) - self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim - self.qk_layer_norms = qk_layer_norms - # Normalization & Scaling - self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm") - self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm") - if self.qk_layer_norms: - self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm") - self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm") - - self.qk_scale = self.head_dim**-0.5 - - # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers). - self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj") - self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj") - self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj") - - self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj") - - def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor: - """ - Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! - - Args: - context (`tf.Tensor`): - Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample. - latents (`tf.Tensor`): - Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to. - - Returns: - `tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross - from context. - """ - context = self.context_layer_norm(context) - latents = self.latents_layer_norm(latents) - batch_size, seq_length, embed_dim = shape_list(context) - - # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! - # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` - q = self.q_proj(latents) - k = self.k_proj(tf.concat([context, latents], axis=-2)) - v = self.v_proj(tf.concat([context, latents], axis=-2)) - - # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) - # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] - q, k, v = [ - tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3]) - for x in (q, k, v) - ] - - if self.qk_layer_norms: - q = self.q_layer_norm(q) - k = self.k_layer_norm(k) - - scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k) - stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True) - attn = tf.nn.softmax(stabilized_scores, axis=-1) - - # Attend & project back to output... - resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v) - return self.output_proj( - tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim)) - ) - - -class TFIdeficsMLP(tf.keras.layers.Layer): - def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs): - """Simple MLP block with intermediate_size and embedding size""" - super().__init__(**kwargs) - self.embed_dim = config.vision_config.embed_dim - self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln") - self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc") - self.act = tf.keras.layers.ReLU(name="act") - self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj") - - def call(self, hidden_states: Optional[tuple[tf.Tensor]]) -> tf.Tensor: - hidden_states = self.ln(hidden_states) - hidden_states = self.fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - - return hidden_states diff --git a/src/transformers/models/idefics/vision_tf.py b/src/transformers/models/idefics/vision_tf.py deleted file mode 100644 index 1d8cf9402218..000000000000 --- a/src/transformers/models/idefics/vision_tf.py +++ /dev/null @@ -1,572 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object""" - -import math -from dataclasses import dataclass -from typing import Optional, Union - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling -from ...modeling_tf_utils import TFPreTrainedModel, shape_list -from ...tf_utils import flatten -from ...utils import ModelOutput, logging -from .configuration_idefics import IdeficsVisionConfig - - -logger = logging.get_logger(__name__) - - -@dataclass -class TFIdeficsVisionModelOutput(ModelOutput): - """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - - Args: - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - image_embeds: Optional[tf.Tensor] = None - last_hidden_state: Optional[tf.Tensor] = None - hidden_states: Optional[tuple[tf.Tensor]] = None - attentions: Optional[tuple[tf.Tensor]] = None - - -class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer): - def __init__(self, config: IdeficsVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = tf.keras.layers.Conv2D( - filters=self.embed_dim, - kernel_size=self.patch_size, - strides=self.patch_size, - use_bias=False, - padding="valid", - data_format="channels_last", - name="patch_embedding", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 - self.position_embedding = tf.keras.layers.Embedding( - self.num_positions, self.embed_dim, name="position_embedding" - ) - # self.position_ids = tf.range(self.num_positions)[tf.newaxis, :] - - def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: - num_patches = shape_list(embeddings)[1] - 1 - pos_embed = self.position_embedding(self.position_ids) - num_positions = shape_list(pos_embed)[1] - 1 - if num_patches == num_positions and height == width: - return pos_embed - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - - embed_dim = shape_list(embeddings)[-1] - num_h_patches = height // self.config.patch_size - num_w_patches = width // self.config.patch_size - num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 - sqrt_num_positions = math.sqrt(float(num_positions)) - patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)) - - scale_height = num_h_patches / sqrt_num_positions - scale_width = num_w_patches / sqrt_num_positions - original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32) - original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32) - # Apply scaling - new_height = tf.cast(original_height * scale_height, tf.int32) - new_width = tf.cast(original_width * scale_width, tf.int32) - - patch_pos_embed = tf.image.resize( - patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC - ) - - if ( - int(num_h_patches) != shape_list(patch_pos_embed)[-3] - or int(num_w_patches) != shape_list(patch_pos_embed)[-2] - ): - raise ValueError( - f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " - f"shape of position embedding ({shape_list(patch_pos_embed)[-2], shape_list(patch_pos_embed)[-1]})" - ) - patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim)) - return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1) - - def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor: - # Input `pixel_values` is NCHW format which doesn't run on CPU so first thing we do is - # transpose it to change it to NHWC. We don't care to transpose it back because - # the Conv2D layer is only hit once for each query - - if isinstance(pixel_values, dict): - pixel_values = pixel_values["pixel_values"] - - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - batch_size, height, width, num_channels = shape_list(pixel_values) - if not interpolate_pos_encoding: - if height != self.image_size or width != self.image_size: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`" - ) - - patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - patch_embeds = flatten(patch_embeds, 1, 2) - - class_embeds = tf.broadcast_to( - self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim] - ) - embeddings = tf.concat([class_embeds, patch_embeds], axis=1) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embedding(self.position_ids) - - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :] - self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding") - if getattr(self, "patch_embedding", None) is not None: - with tf.name_scope(self.patch_embedding.name): - self.patch_embedding.build([None, None, None, self.config.num_channels]) - if getattr(self, "position_embedding", None) is not None: - with tf.name_scope(self.position_embedding.name): - self.position_embedding.build(None) - - -class TFIdeficsVisionAttention(tf.keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj") - self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj") - self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj") - self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: Optional[tf.Tensor] = None, - causal_attention_mask: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[tf.Tensor, Optional[tf.Tensor], Optional[tuple[tf.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - tf.shape(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}", - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]: - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(causal_attention_mask)}" - ) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - if attention_mask is not None: - if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]: - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}" - ) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = tf.nn.softmax(attn_weights, axis=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len)) - else: - attn_weights_reshaped = None - - attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) - - attn_output = tf.linalg.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - tf.shape(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}", - ) - - attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)) - attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3]) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build((self.embed_dim, self.embed_dim)) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build((self.embed_dim, self.embed_dim)) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build((self.embed_dim, self.embed_dim)) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build((self.embed_dim, self.embed_dim)) - - -class TFIdeficsVisionMLP(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.activation_fn = get_tf_activation(config.hidden_act) - self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1") - self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2") - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build(self.config.hidden_size) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build(self.config.intermediate_size) - - -class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer): - def __init__(self, config: IdeficsVisionConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.hidden_size - self.self_attn = TFIdeficsVisionAttention(config, name="self_attn") - self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") - self.mlp = TFIdeficsVisionMLP(config, name="mlp") - self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - causal_attention_mask: tf.Tensor, - output_attentions: Optional[bool] = False, - ) -> tuple[tf.Tensor]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, self.embed_dim]) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, self.embed_dim]) - - -class TFIdeficsVisionEncoder(tf.keras.layers.Layer): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`TFIdeficsVisionEncoderLayer`]. - - Args: - config: IdeficsVisionConfig - """ - - def __init__(self, config: IdeficsVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layers = [ - TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) - ] - self.gradient_checkpointing = False - - def call( - self, - inputs_embeds, - attention_mask: Optional[tf.Tensor] = None, - causal_attention_mask: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = None, - ) -> Union[tuple, TFBaseModelOutput]: - r""" - Args: - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Causal mask for the text model. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = tf.recompute_grad( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - causal_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFIdeficsVisionTransformer(TFPreTrainedModel): - def __init__(self, config: IdeficsVisionConfig, **kwargs): - super().__init__(config, **kwargs) - self.config = config - self.embed_dim = config.hidden_size - - self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings") - self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm") - self.encoder = TFIdeficsVisionEncoder(config, name="encoder") - self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm") - - # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = False, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[tuple, TFBaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - hidden_states = self.pre_layrnorm(hidden_states) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "pre_layrnorm", None) is not None: - with tf.name_scope(self.pre_layrnorm.name): - self.pre_layrnorm.build([None, None, self.embed_dim]) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "post_layernorm", None) is not None: - with tf.name_scope(self.post_layernorm.name): - self.post_layernorm.build([None, self.embed_dim]) diff --git a/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py b/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py deleted file mode 100644 index 182d66b9af28..000000000000 --- a/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py +++ /dev/null @@ -1,71 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OpenAI Image GPT checkpoints.""" - -import argparse - -import torch - -from transformers import ImageGPTConfig, ImageGPTForCausalLM, load_tf_weights_in_imagegpt -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -logging.set_verbosity_info() - - -def convert_imagegpt_checkpoint_to_pytorch(imagegpt_checkpoint_path, model_size, pytorch_dump_folder_path): - # Construct configuration depending on size - MODELS = {"small": (512, 8, 24), "medium": (1024, 8, 36), "large": (1536, 16, 48)} - n_embd, n_head, n_layer = MODELS[model_size] # set model hyperparameters - config = ImageGPTConfig(n_embd=n_embd, n_layer=n_layer, n_head=n_head) - model = ImageGPTForCausalLM(config) - - # Load weights from numpy - load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME - print(f"Save PyTorch model to {pytorch_weights_dump_path}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {pytorch_config_dump_path}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--imagegpt_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint path.", - ) - parser.add_argument( - "--model_size", - default=None, - type=str, - required=True, - help="Size of the model (can be either 'small', 'medium' or 'large').", - ) - parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_imagegpt_checkpoint_to_pytorch( - args.imagegpt_checkpoint_path, args.model_size, args.pytorch_dump_folder_path - ) diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py deleted file mode 100644 index f6738693843b..000000000000 --- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py +++ /dev/null @@ -1,1691 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 LayoutLM model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFMaskedLMOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_layoutlm import LayoutLMConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LayoutLMConfig" - - -class TFLayoutLMEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.max_2d_position_embeddings = config.max_2d_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("x_position_embeddings"): - self.x_position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_2d_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("y_position_embeddings"): - self.y_position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_2d_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("h_position_embeddings"): - self.h_position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_2d_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("w_position_embeddings"): - self.w_position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_2d_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - bbox: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - if bbox is None: - bbox = tf.fill(input_shape + [4], value=0) - try: - left_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 0]) - upper_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 1]) - right_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 2]) - lower_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 3]) - except IndexError as e: - raise IndexError("The `bbox`coordinate values should be within 0-1000 range.") from e - h_position_embeddings = tf.gather(self.h_position_embeddings, bbox[:, :, 3] - bbox[:, :, 1]) - w_position_embeddings = tf.gather(self.w_position_embeddings, bbox[:, :, 2] - bbox[:, :, 0]) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = ( - inputs_embeds - + position_embeds - + token_type_embeds - + left_position_embeddings - + upper_position_embeddings - + right_position_embeddings - + lower_position_embeddings - + h_position_embeddings - + w_position_embeddings - ) - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->LayoutLM -class TFLayoutLMSelfAttention(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFLayoutLMModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->LayoutLM -class TFLayoutLMSelfOutput(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->LayoutLM -class TFLayoutLMAttention(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFLayoutLMSelfAttention(config, name="self") - self.dense_output = TFLayoutLMSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->LayoutLM -class TFLayoutLMIntermediate(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->LayoutLM -class TFLayoutLMOutput(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->LayoutLM -class TFLayoutLMLayer(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFLayoutLMAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFLayoutLMAttention(config, name="crossattention") - self.intermediate = TFLayoutLMIntermediate(config, name="intermediate") - self.bert_output = TFLayoutLMOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->LayoutLM -class TFLayoutLMEncoder(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->LayoutLM -class TFLayoutLMPooler(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->LayoutLM -class TFLayoutLMPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->LayoutLM -class TFLayoutLMLMPredictionHead(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - - self.transform = TFLayoutLMPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->LayoutLM -class TFLayoutLMMLMHead(keras.layers.Layer): - def __init__(self, config: LayoutLMConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFLayoutLMLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@keras_serializable -class TFLayoutLMMainLayer(keras.layers.Layer): - config_class = LayoutLMConfig - - def __init__(self, config: LayoutLMConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFLayoutLMEmbeddings(config, name="embeddings") - self.encoder = TFLayoutLMEncoder(config, name="encoder") - self.pooler = TFLayoutLMPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - if bbox is None: - bbox = tf.fill(dims=input_shape + [4], value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - bbox=bbox, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - # Need to pass these required positional arguments to `Encoder` - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=None, - past_key_values=None, - use_cache=False, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFLayoutLMPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LayoutLMConfig - base_model_prefix = "layoutlm" - - @property - def input_signature(self): - signature = super().input_signature - signature["bbox"] = tf.TensorSpec(shape=(None, None, 4), dtype=tf.int32, name="bbox") - return signature - - -LAYOUTLM_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`LayoutLMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -LAYOUTLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - bbox (`Numpy array` or `tf.Tensor` of shape `({0}, 4)`, *optional*): - Bounding Boxes of each input sequence tokens. Selected in the range `[0, config.max_2d_position_embeddings- - 1]`. - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare LayoutLM Model transformer outputting raw hidden-states without any specific head on top.", - LAYOUTLM_START_DOCSTRING, -) -class TFLayoutLMModel(TFLayoutLMPreTrainedModel): - def __init__(self, config: LayoutLMConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings( - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC - ) - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFLayoutLMModel - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased") - >>> model = TFLayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased") - - >>> words = ["Hello", "world"] - >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] - - >>> token_boxes = [] - >>> for word, box in zip(words, normalized_word_boxes): - ... word_tokens = tokenizer.tokenize(word) - ... token_boxes.extend([box] * len(word_tokens)) - >>> # add bounding boxes of cls + sep tokens - >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] - - >>> encoding = tokenizer(" ".join(words), return_tensors="tf") - >>> input_ids = encoding["input_ids"] - >>> attention_mask = encoding["attention_mask"] - >>> token_type_ids = encoding["token_type_ids"] - >>> bbox = tf.convert_to_tensor([token_boxes]) - - >>> outputs = model( - ... input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids - ... ) - - >>> last_hidden_states = outputs.last_hidden_state - ```""" - outputs = self.layoutlm( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlm", None) is not None: - with tf.name_scope(self.layoutlm.name): - self.layoutlm.build(None) - - -@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING) -class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"cls.seq_relationship", - r"cls.predictions.decoder.weight", - r"nsp___cls", - ] - - def __init__(self, config: LayoutLMConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFLayoutLMForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm") - self.mlm = TFLayoutLMMLMHead(config, input_embeddings=self.layoutlm.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - def get_prefix_bias_name(self) -> str: - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFLayoutLMForMaskedLM - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased") - >>> model = TFLayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased") - - >>> words = ["Hello", "[MASK]"] - >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] - - >>> token_boxes = [] - >>> for word, box in zip(words, normalized_word_boxes): - ... word_tokens = tokenizer.tokenize(word) - ... token_boxes.extend([box] * len(word_tokens)) - >>> # add bounding boxes of cls + sep tokens - >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] - - >>> encoding = tokenizer(" ".join(words), return_tensors="tf") - >>> input_ids = encoding["input_ids"] - >>> attention_mask = encoding["attention_mask"] - >>> token_type_ids = encoding["token_type_ids"] - >>> bbox = tf.convert_to_tensor([token_boxes]) - - >>> labels = tokenizer("Hello world", return_tensors="tf")["input_ids"] - - >>> outputs = model( - ... input_ids=input_ids, - ... bbox=bbox, - ... attention_mask=attention_mask, - ... token_type_ids=token_type_ids, - ... labels=labels, - ... ) - - >>> loss = outputs.loss - ```""" - outputs = self.layoutlm( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlm", None) is not None: - with tf.name_scope(self.layoutlm.name): - self.layoutlm.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """ - LayoutLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - LAYOUTLM_START_DOCSTRING, -) -class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: LayoutLMConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFLayoutLMForSequenceClassification - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased") - >>> model = TFLayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased") - - >>> words = ["Hello", "world"] - >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] - - >>> token_boxes = [] - >>> for word, box in zip(words, normalized_word_boxes): - ... word_tokens = tokenizer.tokenize(word) - ... token_boxes.extend([box] * len(word_tokens)) - >>> # add bounding boxes of cls + sep tokens - >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] - - >>> encoding = tokenizer(" ".join(words), return_tensors="tf") - >>> input_ids = encoding["input_ids"] - >>> attention_mask = encoding["attention_mask"] - >>> token_type_ids = encoding["token_type_ids"] - >>> bbox = tf.convert_to_tensor([token_boxes]) - >>> sequence_label = tf.convert_to_tensor([1]) - - >>> outputs = model( - ... input_ids=input_ids, - ... bbox=bbox, - ... attention_mask=attention_mask, - ... token_type_ids=token_type_ids, - ... labels=sequence_label, - ... ) - - >>> loss = outputs.loss - >>> logits = outputs.logits - ```""" - outputs = self.layoutlm( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlm", None) is not None: - with tf.name_scope(self.layoutlm.name): - self.layoutlm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - LAYOUTLM_START_DOCSTRING, -) -class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"mlm___cls", - r"nsp___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config: LayoutLMConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - - Returns: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFLayoutLMForTokenClassification - - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased") - >>> model = TFLayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased") - - >>> words = ["Hello", "world"] - >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782] - - >>> token_boxes = [] - >>> for word, box in zip(words, normalized_word_boxes): - ... word_tokens = tokenizer.tokenize(word) - ... token_boxes.extend([box] * len(word_tokens)) - >>> # add bounding boxes of cls + sep tokens - >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]] - - >>> encoding = tokenizer(" ".join(words), return_tensors="tf") - >>> input_ids = encoding["input_ids"] - >>> attention_mask = encoding["attention_mask"] - >>> token_type_ids = encoding["token_type_ids"] - >>> bbox = tf.convert_to_tensor([token_boxes]) - >>> token_labels = tf.convert_to_tensor([1, 1, 0, 0]) - - >>> outputs = model( - ... input_ids=input_ids, - ... bbox=bbox, - ... attention_mask=attention_mask, - ... token_type_ids=token_type_ids, - ... labels=token_labels, - ... ) - - >>> loss = outputs.loss - >>> logits = outputs.logits - ```""" - outputs = self.layoutlm( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlm", None) is not None: - with tf.name_scope(self.layoutlm.name): - self.layoutlm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - LayoutLM Model with a span classification head on top for extractive question-answering tasks such as - [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span - start logits` and `span end logits`). - """, - LAYOUTLM_START_DOCSTRING, -) -class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"mlm___cls", - r"nsp___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - - def __init__(self, config: LayoutLMConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="qa_outputs", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - bbox: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - Returns: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFLayoutLMForQuestionAnswering - >>> from datasets import load_dataset - - >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True) - >>> model = TFLayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac") - - >>> dataset = load_dataset("nielsr/funsd", split="train") - >>> example = dataset[0] - >>> question = "what's his name?" - >>> words = example["words"] - >>> boxes = example["bboxes"] - - >>> encoding = tokenizer( - ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="tf" - ... ) - >>> bbox = [] - >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)): - ... if s == 1: - ... bbox.append(boxes[w]) - ... elif i == tokenizer.sep_token_id: - ... bbox.append([1000] * 4) - ... else: - ... bbox.append([0] * 4) - >>> encoding["bbox"] = tf.convert_to_tensor([bbox]) - - >>> word_ids = encoding.word_ids(0) - >>> outputs = model(**encoding) - >>> loss = outputs.loss - >>> start_scores = outputs.start_logits - >>> end_scores = outputs.end_logits - >>> start, end = word_ids[tf.math.argmax(start_scores, -1)[0]], word_ids[tf.math.argmax(end_scores, -1)[0]] - >>> print(" ".join(words[start : end + 1])) - M. Hamann P. Harper, P. Martinez - ```""" - - outputs = self.layoutlm( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlm", None) is not None: - with tf.name_scope(self.layoutlm.name): - self.layoutlm.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFLayoutLMForMaskedLM", - "TFLayoutLMForSequenceClassification", - "TFLayoutLMForTokenClassification", - "TFLayoutLMForQuestionAnswering", - "TFLayoutLMMainLayer", - "TFLayoutLMModel", - "TFLayoutLMPreTrainedModel", -] diff --git a/src/transformers/models/layoutlmv3/modeling_tf_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_tf_layoutlmv3.py deleted file mode 100644 index c0586d58835e..000000000000 --- a/src/transformers/models/layoutlmv3/modeling_tf_layoutlmv3.py +++ /dev/null @@ -1,1767 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 LayoutLMv3 model.""" - -from __future__ import annotations - -import collections -import math - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from .configuration_layoutlmv3 import LayoutLMv3Config - - -_CONFIG_FOR_DOC = "LayoutLMv3Config" - -_DUMMY_INPUT_IDS = [ - [7, 6, 1], - [1, 2, 0], -] - -_DUMMY_BBOX = [ - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], -] - - -LARGE_NEGATIVE = -1e8 - - -class TFLayoutLMv3PatchEmbeddings(keras.layers.Layer): - """LayoutLMv3 image (patch) embeddings.""" - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - patch_sizes = ( - config.patch_size - if isinstance(config.patch_size, collections.abc.Iterable) - else (config.patch_size, config.patch_size) - ) - self.proj = keras.layers.Conv2D( - filters=config.hidden_size, - kernel_size=patch_sizes, - strides=patch_sizes, - padding="valid", - data_format="channels_last", - use_bias=True, - kernel_initializer=get_initializer(config.initializer_range), - name="proj", - ) - self.hidden_size = config.hidden_size - self.num_patches = (config.input_size**2) // (patch_sizes[0] * patch_sizes[1]) - self.config = config - - def call(self, pixel_values: tf.Tensor) -> tf.Tensor: - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) - - embeddings = self.proj(pixel_values) - embeddings = tf.reshape(embeddings, (-1, self.num_patches, self.hidden_size)) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build([None, None, None, self.config.num_channels]) - - -class TFLayoutLMv3TextEmbeddings(keras.layers.Layer): - """ - LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings. - """ - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - self.word_embeddings = keras.layers.Embedding( - config.vocab_size, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="word_embeddings", - ) - self.token_type_embeddings = keras.layers.Embedding( - config.type_vocab_size, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="token_type_embeddings", - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.padding_token_index = config.pad_token_id - self.position_embeddings = keras.layers.Embedding( - config.max_position_embeddings, - config.hidden_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="position_embeddings", - ) - self.x_position_embeddings = keras.layers.Embedding( - config.max_2d_position_embeddings, - config.coordinate_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="x_position_embeddings", - ) - self.y_position_embeddings = keras.layers.Embedding( - config.max_2d_position_embeddings, - config.coordinate_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="y_position_embeddings", - ) - self.h_position_embeddings = keras.layers.Embedding( - config.max_2d_position_embeddings, - config.shape_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="h_position_embeddings", - ) - self.w_position_embeddings = keras.layers.Embedding( - config.max_2d_position_embeddings, - config.shape_size, - embeddings_initializer=get_initializer(config.initializer_range), - name="w_position_embeddings", - ) - self.max_2d_positions = config.max_2d_position_embeddings - self.config = config - - def calculate_spatial_position_embeddings(self, bbox: tf.Tensor) -> tf.Tensor: - try: - left_position_ids = bbox[:, :, 0] - upper_position_ids = bbox[:, :, 1] - right_position_ids = bbox[:, :, 2] - lower_position_ids = bbox[:, :, 3] - except IndexError as exception: - raise IndexError("Bounding box is not of shape (batch_size, seq_length, 4).") from exception - - try: - left_position_embeddings = self.x_position_embeddings(left_position_ids) - upper_position_embeddings = self.y_position_embeddings(upper_position_ids) - right_position_embeddings = self.x_position_embeddings(right_position_ids) - lower_position_embeddings = self.y_position_embeddings(lower_position_ids) - except IndexError as exception: - raise IndexError( - f"The `bbox` coordinate values should be within 0-{self.max_2d_positions} range." - ) from exception - - max_position_id = self.max_2d_positions - 1 - h_position_embeddings = self.h_position_embeddings( - tf.clip_by_value(bbox[:, :, 3] - bbox[:, :, 1], 0, max_position_id) - ) - w_position_embeddings = self.w_position_embeddings( - tf.clip_by_value(bbox[:, :, 2] - bbox[:, :, 0], 0, max_position_id) - ) - - # LayoutLMv1 sums the spatial embeddings, but LayoutLMv3 concatenates them. - spatial_position_embeddings = tf.concat( - [ - left_position_embeddings, - upper_position_embeddings, - right_position_embeddings, - lower_position_embeddings, - h_position_embeddings, - w_position_embeddings, - ], - axis=-1, - ) - return spatial_position_embeddings - - def create_position_ids_from_inputs_embeds(self, inputs_embds: tf.Tensor) -> tf.Tensor: - """ - We are provided embeddings directly. We cannot infer which are padded, so just generate sequential position - ids. - """ - input_shape = tf.shape(inputs_embds) - sequence_length = input_shape[1] - start_index = self.padding_token_index + 1 - end_index = self.padding_token_index + sequence_length + 1 - position_ids = tf.range(start_index, end_index, dtype=tf.int32) - batch_size = input_shape[0] - position_ids = tf.reshape(position_ids, (1, sequence_length)) - position_ids = tf.tile(position_ids, (batch_size, 1)) - return position_ids - - def create_position_ids_from_input_ids(self, input_ids: tf.Tensor) -> tf.Tensor: - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_token_index + 1. - """ - mask = tf.cast(tf.not_equal(input_ids, self.padding_token_index), input_ids.dtype) - position_ids = tf.cumsum(mask, axis=1) * mask - position_ids = position_ids + self.padding_token_index - return position_ids - - def create_position_ids(self, input_ids: tf.Tensor, inputs_embeds: tf.Tensor) -> tf.Tensor: - if input_ids is None: - return self.create_position_ids_from_inputs_embeds(inputs_embeds) - else: - return self.create_position_ids_from_input_ids(input_ids) - - def call( - self, - input_ids: tf.Tensor | None = None, - bbox: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - if position_ids is None: - position_ids = self.create_position_ids(input_ids, inputs_embeds) - - if input_ids is not None: - input_shape = tf.shape(input_ids) - else: - input_shape = tf.shape(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.zeros(input_shape, dtype=position_ids.dtype) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.word_embeddings.input_dim) - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - - spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox) - - embeddings += spatial_position_embeddings - - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings, training=training) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "word_embeddings", None) is not None: - with tf.name_scope(self.word_embeddings.name): - self.word_embeddings.build(None) - if getattr(self, "token_type_embeddings", None) is not None: - with tf.name_scope(self.token_type_embeddings.name): - self.token_type_embeddings.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "position_embeddings", None) is not None: - with tf.name_scope(self.position_embeddings.name): - self.position_embeddings.build(None) - if getattr(self, "x_position_embeddings", None) is not None: - with tf.name_scope(self.x_position_embeddings.name): - self.x_position_embeddings.build(None) - if getattr(self, "y_position_embeddings", None) is not None: - with tf.name_scope(self.y_position_embeddings.name): - self.y_position_embeddings.build(None) - if getattr(self, "h_position_embeddings", None) is not None: - with tf.name_scope(self.h_position_embeddings.name): - self.h_position_embeddings.build(None) - if getattr(self, "w_position_embeddings", None) is not None: - with tf.name_scope(self.w_position_embeddings.name): - self.w_position_embeddings.build(None) - - -class TFLayoutLMv3SelfAttention(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.attention_score_normaliser = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="query", - ) - self.key = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="key", - ) - self.value = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="value", - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.has_relative_attention_bias = config.has_relative_attention_bias - self.has_spatial_attention_bias = config.has_spatial_attention_bias - self.config = config - - def transpose_for_scores(self, x: tf.Tensor): - shape = tf.shape(x) - new_shape = ( - shape[0], # batch_size - shape[1], # seq_length - self.num_attention_heads, - self.attention_head_size, - ) - x = tf.reshape(x, new_shape) - return tf.transpose(x, perm=[0, 2, 1, 3]) # batch_size, num_heads, seq_length, attention_head_size - - def cogview_attention(self, attention_scores: tf.Tensor, alpha: float | int = 32): - """ - https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation - (PB-Relax). A replacement of the original keras.layers.Softmax(axis=-1)(attention_scores). Seems the new - attention_probs will result in a slower speed and a little bias. Can use - tf.debugging.assert_near(standard_attention_probs, cogview_attention_probs, atol=1e-08) for comparison. The - smaller atol (e.g., 1e-08), the better. - """ - scaled_attention_scores = attention_scores / alpha - max_value = tf.expand_dims(tf.reduce_max(scaled_attention_scores, axis=-1), axis=-1) - new_attention_scores = (scaled_attention_scores - max_value) * alpha - return tf.math.softmax(new_attention_scores, axis=-1) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None, - head_mask: tf.Tensor | None, - output_attentions: bool, - rel_pos: tf.Tensor | None = None, - rel_2d_pos: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # Take the dot product between "query" and "key" to get the raw attention scores. - normalised_query_layer = query_layer / self.attention_score_normaliser - transposed_key_layer = tf.transpose( - key_layer, perm=[0, 1, 3, 2] - ) # batch_size, num_heads, attention_head_size, seq_length - attention_scores = tf.matmul(normalised_query_layer, transposed_key_layer) - - if self.has_relative_attention_bias and self.has_spatial_attention_bias: - attention_scores += (rel_pos + rel_2d_pos) / self.attention_score_normaliser - elif self.has_relative_attention_bias: - attention_scores += rel_pos / self.attention_score_normaliser - - if attention_mask is not None: - # Apply the attention mask (is precomputed for all layers in TFLayoutLMv3Model call() function) - attention_scores += attention_mask - - # Normalize the attention scores to probabilities. - # Use the trick of CogView paper to stabilize training. - attention_probs = self.cogview_attention(attention_scores) - - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to. - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose( - context_layer, perm=[0, 2, 1, 3] - ) # batch_size, seq_length, num_heads, attention_head_size - shape = tf.shape(context_layer) - context_layer = tf.reshape( - context_layer, (shape[0], shape[1], self.all_head_size) - ) # batch_size, seq_length, num_heads * attention_head_size - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from models.roberta.modeling_tf_roberta.TFRobertaSelfOutput -class TFLayoutLMv3SelfOutput(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFLayoutLMv3Attention(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - self.self_attention = TFLayoutLMv3SelfAttention(config, name="self") - self.self_output = TFLayoutLMv3SelfOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None, - head_mask: tf.Tensor | None, - output_attentions: bool, - rel_pos: tf.Tensor | None = None, - rel_2d_pos: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor]: - self_outputs = self.self_attention( - hidden_states, - attention_mask, - head_mask, - output_attentions, - rel_pos, - rel_2d_pos, - training=training, - ) - attention_output = self.self_output(self_outputs[0], hidden_states, training=training) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "self_output", None) is not None: - with tf.name_scope(self.self_output.name): - self.self_output.build(None) - - -# Copied from models.roberta.modeling_tf_bert.TFRobertaIntermediate -class TFLayoutLMv3Intermediate(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from models.roberta.modeling_tf_bert.TFRobertaOutput -class TFLayoutLMv3Output(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFLayoutLMv3Layer(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - self.attention = TFLayoutLMv3Attention(config, name="attention") - self.intermediate = TFLayoutLMv3Intermediate(config, name="intermediate") - self.bert_output = TFLayoutLMv3Output(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None, - head_mask: tf.Tensor | None, - output_attentions: bool, - rel_pos: tf.Tensor | None = None, - rel_2d_pos: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor]: - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - training=training, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - intermediate_output = self.intermediate(attention_output) - layer_output = self.bert_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + outputs - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - - -class TFLayoutLMv3Encoder(keras.layers.Layer): - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFLayoutLMv3Layer(config, name=f"layer.{i}") for i in range(config.num_hidden_layers)] - - self.has_relative_attention_bias = config.has_relative_attention_bias - self.has_spatial_attention_bias = config.has_spatial_attention_bias - - if self.has_relative_attention_bias: - self.rel_pos_bins = config.rel_pos_bins - self.max_rel_pos = config.max_rel_pos - self.rel_pos_bias = keras.layers.Dense( - units=config.num_attention_heads, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=False, - name="rel_pos_bias", - ) - - if self.has_spatial_attention_bias: - self.max_rel_2d_pos = config.max_rel_2d_pos - self.rel_2d_pos_bins = config.rel_2d_pos_bins - self.rel_pos_x_bias = keras.layers.Dense( - units=config.num_attention_heads, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=False, - name="rel_pos_x_bias", - ) - self.rel_pos_y_bias = keras.layers.Dense( - units=config.num_attention_heads, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=False, - name="rel_pos_y_bias", - ) - - def relative_position_bucket(self, relative_positions: tf.Tensor, num_buckets: int, max_distance: int): - # the negative relative positions are assigned to the interval [0, num_buckets / 2] - # we deal with this by assigning absolute relative positions to the interval [0, num_buckets / 2] - # and then offsetting the positive relative positions by num_buckets / 2 at the end - num_buckets = num_buckets // 2 - buckets = tf.abs(relative_positions) - - # half of the buckets are for exact increments in positions - max_exact_buckets = num_buckets // 2 - is_small = buckets < max_exact_buckets - - # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance - buckets_log_ratio = tf.math.log(tf.cast(buckets, tf.float32) / max_exact_buckets) - distance_log_ratio = math.log(max_distance / max_exact_buckets) - buckets_big_offset = ( - buckets_log_ratio / distance_log_ratio * (num_buckets - max_exact_buckets) - ) # scale is [0, num_buckets - max_exact_buckets] - buckets_big = max_exact_buckets + buckets_big_offset # scale is [max_exact_buckets, num_buckets] - buckets_big = tf.cast(buckets_big, buckets.dtype) - buckets_big = tf.minimum(buckets_big, num_buckets - 1) - - return (tf.cast(relative_positions > 0, buckets.dtype) * num_buckets) + tf.where( - is_small, buckets, buckets_big - ) - - def _cal_pos_emb( - self, - dense_layer: keras.layers.Dense, - position_ids: tf.Tensor, - num_buckets: int, - max_distance: int, - ): - rel_pos_matrix = tf.expand_dims(position_ids, axis=-2) - tf.expand_dims(position_ids, axis=-1) - rel_pos = self.relative_position_bucket(rel_pos_matrix, num_buckets, max_distance) - rel_pos_one_hot = tf.one_hot(rel_pos, depth=num_buckets, dtype=self.compute_dtype) - embedding = dense_layer(rel_pos_one_hot) - # batch_size, seq_length, seq_length, num_heads --> batch_size, num_heads, seq_length, seq_length - embedding = tf.transpose(embedding, [0, 3, 1, 2]) - embedding = tf.cast(embedding, dtype=self.compute_dtype) - return embedding - - def _cal_1d_pos_emb(self, position_ids: tf.Tensor): - return self._cal_pos_emb(self.rel_pos_bias, position_ids, self.rel_pos_bins, self.max_rel_pos) - - def _cal_2d_pos_emb(self, bbox: tf.Tensor): - position_coord_x = bbox[:, :, 0] # left - position_coord_y = bbox[:, :, 3] # bottom - rel_pos_x = self._cal_pos_emb( - self.rel_pos_x_bias, - position_coord_x, - self.rel_2d_pos_bins, - self.max_rel_2d_pos, - ) - rel_pos_y = self._cal_pos_emb( - self.rel_pos_y_bias, - position_coord_y, - self.rel_2d_pos_bins, - self.max_rel_2d_pos, - ) - rel_2d_pos = rel_pos_x + rel_pos_y - return rel_2d_pos - - def call( - self, - hidden_states: tf.Tensor, - bbox: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - position_ids: tf.Tensor | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor] | tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None - rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - training=training, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if return_dict: - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - else: - return tuple( - value for value in [hidden_states, all_hidden_states, all_self_attentions] if value is not None - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rel_pos_bias", None) is not None: - with tf.name_scope(self.rel_pos_bias.name): - self.rel_pos_bias.build([None, None, self.rel_pos_bins]) - if getattr(self, "rel_pos_x_bias", None) is not None: - with tf.name_scope(self.rel_pos_x_bias.name): - self.rel_pos_x_bias.build([None, None, self.rel_2d_pos_bins]) - if getattr(self, "rel_pos_y_bias", None) is not None: - with tf.name_scope(self.rel_pos_y_bias.name): - self.rel_pos_y_bias.build([None, None, self.rel_2d_pos_bins]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFLayoutLMv3MainLayer(keras.layers.Layer): - config_class = LayoutLMv3Config - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - - self.config = config - - if config.text_embed: - self.embeddings = TFLayoutLMv3TextEmbeddings(config, name="embeddings") - - if config.visual_embed: - self.patch_embed = TFLayoutLMv3PatchEmbeddings(config, name="patch_embed") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - - if config.has_relative_attention_bias or config.has_spatial_attention_bias: - image_size = config.input_size // config.patch_size - self.init_visual_bbox(image_size=(image_size, image_size)) - - self.norm = keras.layers.LayerNormalization(epsilon=1e-6, name="norm") - - self.encoder = TFLayoutLMv3Encoder(config, name="encoder") - - def build(self, input_shape=None): - if self.config.visual_embed: - image_size = self.config.input_size // self.config.patch_size - self.cls_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer="zeros", - trainable=True, - dtype=tf.float32, - name="cls_token", - ) - self.pos_embed = self.add_weight( - shape=(1, image_size * image_size + 1, self.config.hidden_size), - initializer="zeros", - trainable=True, - dtype=tf.float32, - name="pos_embed", - ) - - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "patch_embed", None) is not None: - with tf.name_scope(self.patch_embed.name): - self.patch_embed.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build([None, None, self.config.hidden_size]) - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.word_embeddings.weight = value - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - def init_visual_bbox(self, image_size: tuple[int, int], max_len: int = 1000): - # We should not hardcode max_len to 1000, but it is done by the reference implementation, - # so we keep it for compatibility with the pretrained weights. The more correct approach - # would have been to pass on max_len=config.max_2d_position_embeddings - 1. - height, width = image_size - - visual_bbox_x = tf.range(0, max_len * (width + 1), max_len) // width - visual_bbox_x = tf.expand_dims(visual_bbox_x, axis=0) - visual_bbox_x = tf.tile(visual_bbox_x, [width, 1]) # (width, width + 1) - - visual_bbox_y = tf.range(0, max_len * (height + 1), max_len) // height - visual_bbox_y = tf.expand_dims(visual_bbox_y, axis=1) - visual_bbox_y = tf.tile(visual_bbox_y, [1, height]) # (height + 1, height) - - visual_bbox = tf.stack( - [visual_bbox_x[:, :-1], visual_bbox_y[:-1], visual_bbox_x[:, 1:], visual_bbox_y[1:]], - axis=-1, - ) - visual_bbox = tf.reshape(visual_bbox, [-1, 4]) - - cls_token_box = tf.constant([[1, 1, max_len - 1, max_len - 1]], dtype=tf.int32) - self.visual_bbox = tf.concat([cls_token_box, visual_bbox], axis=0) - - def calculate_visual_bbox(self, batch_size: int, dtype: tf.DType): - visual_bbox = tf.expand_dims(self.visual_bbox, axis=0) - visual_bbox = tf.tile(visual_bbox, [batch_size, 1, 1]) - visual_bbox = tf.cast(visual_bbox, dtype=dtype) - return visual_bbox - - def embed_image(self, pixel_values: tf.Tensor) -> tf.Tensor: - embeddings = self.patch_embed(pixel_values) - - # add [CLS] token - batch_size = tf.shape(embeddings)[0] - cls_tokens = tf.tile(self.cls_token, [batch_size, 1, 1]) - embeddings = tf.concat([cls_tokens, embeddings], axis=1) - - # add position embeddings - if getattr(self, "pos_embed", None) is not None: - embeddings += self.pos_embed - - embeddings = self.norm(embeddings) - return embeddings - - def get_extended_attention_mask(self, attention_mask: tf.Tensor) -> tf.Tensor: - # Adapted from transformers.modelling_utils.ModuleUtilsMixin.get_extended_attention_mask - - n_dims = len(attention_mask.shape) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if n_dims == 3: - extended_attention_mask = tf.expand_dims(attention_mask, axis=1) - elif n_dims == 2: - # Provided a padding mask of dimensions [batch_size, seq_length]. - # Make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]. - extended_attention_mask = tf.expand_dims(attention_mask, axis=1) # (batch_size, 1, seq_length) - extended_attention_mask = tf.expand_dims(extended_attention_mask, axis=1) # (batch_size, 1, 1, seq_length) - else: - raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape}).") - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, self.compute_dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * LARGE_NEGATIVE - - return extended_attention_mask - - def get_head_mask(self, head_mask: tf.Tensor | None) -> tf.Tensor | list[tf.Tensor | None]: - if head_mask is None: - return [None] * self.config.num_hidden_layers - - n_dims = tf.rank(head_mask) - if n_dims == 1: - # Gets a tensor with masks for each head (H). - head_mask = tf.expand_dims(head_mask, axis=0) # 1, num_heads - head_mask = tf.expand_dims(head_mask, axis=0) # 1, 1, num_heads - head_mask = tf.expand_dims(head_mask, axis=-1) # 1, 1, num_heads, 1 - head_mask = tf.expand_dims(head_mask, axis=-1) # 1, 1, num_heads, 1, 1 - head_mask = tf.tile( - head_mask, [self.config.num_hidden_layers, 1, 1, 1, 1] - ) # seq_length, 1, num_heads, 1, 1 - elif n_dims == 2: - # Gets a tensor with masks for each layer (L) and head (H). - head_mask = tf.expand_dims(head_mask, axis=1) # seq_length, 1, num_heads - head_mask = tf.expand_dims(head_mask, axis=-1) # seq_length, 1, num_heads, 1 - head_mask = tf.expand_dims(head_mask, axis=-1) # seq_length, 1, num_heads, 1, 1 - elif n_dims != 5: - raise ValueError(f"Wrong shape for head_mask (shape {head_mask.shape}).") - assert tf.rank(head_mask) == 5, f"Got head_mask rank of {tf.rank(head_mask)}, but require 5." - head_mask = tf.cast(head_mask, self.compute_dtype) - return head_mask - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - bbox: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor] | tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - # This method can be called with a variety of modalities: - # 1. text + layout - # 2. text + layout + image - # 3. image - # The complexity of this method is mostly just due to handling of these different modalities. - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if input_ids is not None: - input_shape = tf.shape(input_ids) - batch_size = input_shape[0] - seq_length = input_shape[1] - elif inputs_embeds is not None: - input_shape = tf.shape(inputs_embeds) - batch_size = input_shape[0] - seq_length = input_shape[1] - elif pixel_values is not None: - batch_size = tf.shape(pixel_values)[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values") - - # Determine which integer dtype to use. - if input_ids is not None: - int_dtype = input_ids.dtype - elif bbox is not None: - int_dtype = bbox.dtype - elif attention_mask is not None: - int_dtype = attention_mask.dtype - elif token_type_ids is not None: - int_dtype = token_type_ids.dtype - else: - int_dtype = tf.int32 - - if input_ids is not None or inputs_embeds is not None: - if attention_mask is None: - attention_mask = tf.ones((batch_size, seq_length), dtype=int_dtype) - if token_type_ids is None: - token_type_ids = tf.zeros((batch_size, seq_length), dtype=int_dtype) - if bbox is None: - bbox = tf.zeros((batch_size, seq_length, 4), dtype=int_dtype) - - embedding_output = self.embeddings( - input_ids=input_ids, - bbox=bbox, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - training=training, - ) - - final_bbox = None - final_position_ids = None - if pixel_values is not None: - # embed image - visual_embeddings = self.embed_image(pixel_values) - - # calculate attention mask - visual_attention_mask = tf.ones((batch_size, tf.shape(visual_embeddings)[1]), dtype=int_dtype) - if attention_mask is None: - attention_mask = visual_attention_mask - else: - attention_mask = tf.concat([attention_mask, visual_attention_mask], axis=1) - - # calculate bounding boxes - if self.config.has_spatial_attention_bias: - visual_bbox = self.calculate_visual_bbox(batch_size, int_dtype) - if bbox is None: - final_bbox = visual_bbox - else: - final_bbox = tf.concat([bbox, visual_bbox], axis=1) - - # calculate position IDs - if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: - visual_position_ids = tf.range(0, tf.shape(visual_embeddings)[1], dtype=int_dtype) - visual_position_ids = tf.expand_dims(visual_position_ids, axis=0) - visual_position_ids = tf.tile(visual_position_ids, [batch_size, 1]) - - if input_ids is not None or inputs_embeds is not None: - position_ids = tf.expand_dims(tf.range(0, seq_length, dtype=int_dtype), axis=0) - position_ids = tf.tile(position_ids, [batch_size, 1]) - final_position_ids = tf.concat([position_ids, visual_position_ids], axis=1) - else: - final_position_ids = visual_position_ids - - # calculate embeddings - if input_ids is None and inputs_embeds is None: - embedding_output = visual_embeddings - else: - embedding_output = tf.concat([embedding_output, visual_embeddings], axis=1) - embedding_output = self.LayerNorm(embedding_output) - embedding_output = self.dropout(embedding_output, training=training) - - elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: - if self.config.has_relative_attention_bias: - position_ids = tf.expand_dims(tf.range(0, seq_length, dtype=int_dtype), axis=0) - position_ids = tf.tile(position_ids, [batch_size, 1]) - final_position_ids = position_ids - - if self.config.has_spatial_attention_bias: - final_bbox = bbox - - extended_attention_mask = self.get_extended_attention_mask(attention_mask) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x seq_length x seq_length - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask) - - encoder_outputs = self.encoder( - embedding_output, - bbox=final_bbox, - position_ids=final_position_ids, - attention_mask=extended_attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = encoder_outputs[0] - - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class TFLayoutLMv3PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LayoutLMv3Config - base_model_prefix = "layoutlmv3" - - @property - def input_signature(self): - sig = super().input_signature - sig["bbox"] = tf.TensorSpec((None, None, 4), tf.int32, name="bbox") - return sig - - -LAYOUTLMV3_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`LayoutLMv3Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -LAYOUTLMV3_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] - token. See `pixel_values` for `patch_sequence_length`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - bbox (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length, 4)`, *optional*): - Bounding boxes of each input sequence tokens. Selected in the range `[0, - config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) - format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, - y1) represents the position of the lower right corner. - - Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] - token. See `pixel_values` for `patch_sequence_length`. - - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size, - config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height / - config.patch_size) * (width / config.patch_size))`. - - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] - token. See `pixel_values` for `patch_sequence_length`. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] - token. See `pixel_values` for `patch_sequence_length`. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] - token. See `pixel_values` for `patch_sequence_length`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert *input_ids* indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LayoutLMv3 Model transformer outputting raw hidden-states without any specific head on top.", - LAYOUTLMV3_START_DOCSTRING, -) -class TFLayoutLMv3Model(TFLayoutLMv3PreTrainedModel): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"position_ids"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.layoutlmv3 = TFLayoutLMv3MainLayer(config, name="layoutlmv3") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: tf.Tensor | None = None, - bbox: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor] | tuple[tf.Tensor, tf.Tensor] | tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoProcessor, TFAutoModel - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) - >>> model = TFAutoModel.from_pretrained("microsoft/layoutlmv3-base") - - >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") - >>> example = dataset[0] - >>> image = example["image"] - >>> words = example["tokens"] - >>> boxes = example["bboxes"] - - >>> encoding = processor(image, words, boxes=boxes, return_tensors="tf") - - >>> outputs = model(**encoding) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - - outputs = self.layoutlmv3( - input_ids=input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlmv3", None) is not None: - with tf.name_scope(self.layoutlmv3.name): - self.layoutlmv3.build(None) - - -class TFLayoutLMv3ClassificationHead(keras.layers.Layer): - """ - Head for sentence-level classification tasks. Reference: RobertaClassificationHead - """ - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - activation="tanh", - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout( - classifier_dropout, - name="dropout", - ) - self.out_proj = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="out_proj", - ) - self.config = config - - def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: - outputs = self.dropout(inputs, training=training) - outputs = self.dense(outputs) - outputs = self.dropout(outputs, training=training) - outputs = self.out_proj(outputs) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the - [CLS] token) e.g. for document image classification tasks such as the - [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset. - """, - LAYOUTLMV3_START_DOCSTRING, -) -class TFLayoutLMv3ForSequenceClassification(TFLayoutLMv3PreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"position_ids"] - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(config, **kwargs) - self.config = config - self.layoutlmv3 = TFLayoutLMv3MainLayer(config, name="layoutlmv3") - self.classifier = TFLayoutLMv3ClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - bbox: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - training: bool | None = False, - ) -> ( - TFSequenceClassifierOutput - | tuple[tf.Tensor] - | tuple[tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor] - ): - """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoProcessor, TFAutoModelForSequenceClassification - >>> from datasets import load_dataset - >>> import tensorflow as tf - - >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) - >>> model = TFAutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base") - - >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") - >>> example = dataset[0] - >>> image = example["image"] - >>> words = example["tokens"] - >>> boxes = example["bboxes"] - - >>> encoding = processor(image, words, boxes=boxes, return_tensors="tf") - >>> sequence_label = tf.convert_to_tensor([1]) - - >>> outputs = model(**encoding, labels=sequence_label) - >>> loss = outputs.loss - >>> logits = outputs.logits - ```""" - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.layoutlmv3( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - bbox=bbox, - pixel_values=pixel_values, - training=training, - ) - sequence_output = outputs[0][:, 0, :] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlmv3", None) is not None: - with tf.name_scope(self.layoutlmv3.name): - self.layoutlmv3.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g. - for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/), - [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and - [Kleister-NDA](https://github.com/applicaai/kleister-nda). - """, - LAYOUTLMV3_START_DOCSTRING, -) -class TFLayoutLMv3ForTokenClassification(TFLayoutLMv3PreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"position_ids"] - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(config, **kwargs) - self.num_labels = config.num_labels - - self.layoutlmv3 = TFLayoutLMv3MainLayer(config, name="layoutlmv3") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - if config.num_labels < 10: - self.classifier = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - else: - self.classifier = TFLayoutLMv3ClassificationHead(config, name="classifier") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: tf.Tensor | None = None, - bbox: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - pixel_values: tf.Tensor | None = None, - training: bool | None = False, - ) -> ( - TFTokenClassifierOutput - | tuple[tf.Tensor] - | tuple[tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor] - ): - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - - Returns: - - Examples: - - ```python - >>> from transformers import AutoProcessor, TFAutoModelForTokenClassification - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) - >>> model = TFAutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7) - - >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") - >>> example = dataset[0] - >>> image = example["image"] - >>> words = example["tokens"] - >>> boxes = example["bboxes"] - >>> word_labels = example["ner_tags"] - - >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="tf") - - >>> outputs = model(**encoding) - >>> loss = outputs.loss - >>> logits = outputs.logits - ```""" - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.layoutlmv3( - input_ids, - bbox=bbox, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - pixel_values=pixel_values, - training=training, - ) - if input_ids is not None: - input_shape = tf.shape(input_ids) - else: - input_shape = tf.shape(inputs_embeds)[:-1] - - seq_length = input_shape[1] - # only take the text part of the output representations - sequence_output = outputs[0][:, :seq_length] - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlmv3", None) is not None: - with tf.name_scope(self.layoutlmv3.name): - self.layoutlmv3.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - LayoutLMv3 Model with a span classification head on top for extractive question-answering tasks such as - [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to - compute `span start logits` and `span end logits`). - """, - LAYOUTLMV3_START_DOCSTRING, -) -class TFLayoutLMv3ForQuestionAnswering(TFLayoutLMv3PreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"position_ids"] - - def __init__(self, config: LayoutLMv3Config, **kwargs): - super().__init__(config, **kwargs) - - self.num_labels = config.num_labels - - self.layoutlmv3 = TFLayoutLMv3MainLayer(config, name="layoutlmv3") - self.qa_outputs = TFLayoutLMv3ClassificationHead(config, name="qa_outputs") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - start_positions: tf.Tensor | None = None, - end_positions: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - bbox: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> ( - TFQuestionAnsweringModelOutput - | tuple[tf.Tensor] - | tuple[tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor] - | tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor] - ): - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - Returns: - - Examples: - - ```python - >>> from transformers import AutoProcessor, TFAutoModelForQuestionAnswering - >>> from datasets import load_dataset - >>> import tensorflow as tf - - >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) - >>> model = TFAutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base") - - >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") - >>> example = dataset[0] - >>> image = example["image"] - >>> question = "what's his name?" - >>> words = example["tokens"] - >>> boxes = example["bboxes"] - - >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="tf") - >>> start_positions = tf.convert_to_tensor([1]) - >>> end_positions = tf.convert_to_tensor([3]) - - >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions) - >>> loss = outputs.loss - >>> start_scores = outputs.start_logits - >>> end_scores = outputs.end_logits - ```""" - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.layoutlmv3( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - bbox=bbox, - pixel_values=pixel_values, - training=training, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output, training=training) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions, "end_position": end_positions} - loss = self.hf_compute_loss(labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layoutlmv3", None) is not None: - with tf.name_scope(self.layoutlmv3.name): - self.layoutlmv3.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build(None) - - -__all__ = [ - "TFLayoutLMv3ForQuestionAnswering", - "TFLayoutLMv3ForSequenceClassification", - "TFLayoutLMv3ForTokenClassification", - "TFLayoutLMv3Model", - "TFLayoutLMv3PreTrainedModel", -] diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py deleted file mode 100644 index f499ffac30c9..000000000000 --- a/src/transformers/models/led/modeling_tf_led.py +++ /dev/null @@ -1,2663 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 LED model.""" - -from __future__ import annotations - -import random -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions - -# Public API -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_led import LEDConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "allenai/led-base-16384" -_CONFIG_FOR_DOC = "LEDConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFLEDLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - super().__init__(num_embeddings, embedding_dim, **kwargs) - - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): - """Input is expected to be of size [bsz x seqlen].""" - seq_len = input_shape[1] - position_ids = tf.range(seq_len, delta=1, name="range") - position_ids += past_key_values_length - - return super().call(tf.cast(position_ids, dtype=tf.int32)) - - -# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerSelfAttention with TFLongformer->TFLEDEncoder -class TFLEDEncoderSelfAttention(keras.layers.Layer): - def __init__(self, config, layer_id, **kwargs): - super().__init__(**kwargs) - self.config = config - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - - self.num_heads = config.num_attention_heads - self.head_dim = int(config.hidden_size / config.num_attention_heads) - self.embed_dim = config.hidden_size - self.query = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query", - ) - self.key = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key", - ) - self.value = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value", - ) - - # separate projection layers for tokens with global attention - self.query_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query_global", - ) - self.key_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key_global", - ) - self.value_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value_global", - ) - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.layer_id = layer_id - attention_window = config.attention_window[self.layer_id] - - assert attention_window % 2 == 0, ( - f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}" - ) - assert attention_window > 0, ( - f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" - ) - - self.one_sided_attn_window_size = attention_window // 2 - - def build(self, input_shape=None): - if not self.built: - with tf.name_scope("query_global"): - self.query_global.build((self.config.hidden_size,)) - with tf.name_scope("key_global"): - self.key_global.build((self.config.hidden_size,)) - with tf.name_scope("value_global"): - self.value_global.build((self.config.hidden_size,)) - - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "query_global", None) is not None: - with tf.name_scope(self.query_global.name): - self.query_global.build([None, None, self.config.hidden_size]) - if getattr(self, "key_global", None) is not None: - with tf.name_scope(self.key_global.name): - self.key_global.build([None, None, self.config.hidden_size]) - if getattr(self, "value_global", None) is not None: - with tf.name_scope(self.value_global.name): - self.value_global.build([None, None, self.config.hidden_size]) - - def call( - self, - inputs, - training=False, - ): - """ - LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to - *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer. - - The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to: - - - -10000: no attention - - 0: local attention - - +10000: global attention - """ - # retrieve input args - ( - hidden_states, - attention_mask, - layer_head_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - ) = inputs - - # project hidden states - query_vectors = self.query(hidden_states) - key_vectors = self.key(hidden_states) - value_vectors = self.value(hidden_states) - batch_size, seq_len, embed_dim = shape_list(hidden_states) - - tf.debugging.assert_equal( - embed_dim, - self.embed_dim, - message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", - ) - - # normalize query - query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) - query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - - # attn_probs = (batch_size, seq_len, num_heads, window*2+1) - attn_scores = self._sliding_chunks_query_key_matmul( - query_vectors, key_vectors, self.one_sided_attn_window_size - ) - - # values to pad for attention probs - remove_from_windowed_attention_mask = attention_mask != 0 - # cast to fp32/fp16 then replace 1's with -inf - float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE - - # diagonal mask with zeros everywhere and -inf inplace of padding - diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask)), - float_mask, - self.one_sided_attn_window_size, - ) - - # pad local attention probs - attn_scores += diagonal_mask - - tf.debugging.assert_equal( - shape_list(attn_scores), - [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], - message=( - f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}," - f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}" - ), - ) - - # compute global attn indices required through out forward fn - ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) = self._get_global_attn_indices(is_index_global_attn) - - # this function is only relevant for global attention - if is_global_attn: - attn_scores = self._concat_with_global_key_attn_probs( - attn_scores=attn_scores, - query_vectors=query_vectors, - key_vectors=key_vectors, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ) - - attn_probs = stable_softmax(attn_scores, axis=-1) - - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - # Make sure to create a mask with the proper shape: - # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] - # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - if is_global_attn: - masked_index = tf.tile( - is_index_masked[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ) - else: - masked_index = tf.tile( - is_index_masked[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ) - attn_probs = tf.where( - masked_index, - tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), - attn_probs, - ) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs - - # apply dropout - attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - - # if global attention, compute sum of global and local attn - - if is_global_attn: - attn_output = self._compute_attn_output_with_global_indices( - value_vectors=value_vectors, - attn_probs=attn_probs, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ) - else: - attn_output = self._sliding_chunks_matmul_attn_probs_value( - attn_probs, value_vectors, self.one_sided_attn_window_size - ) - - tf.debugging.assert_equal( - shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" - ) - - attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) - - # compute value for global attention and overwrite to attention output - if is_global_attn: - attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( - attn_output=attn_output, - hidden_states=hidden_states, - max_num_global_attn_indices=max_num_global_attn_indices, - layer_head_mask=layer_head_mask, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - is_index_masked=is_index_masked, - training=training, - ) - else: - # Leave attn_output unchanged - global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) - - # make sure that local attention probabilities are set to 0 for indices of global attn - # Make sure to create a mask with the proper shape: - # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] - # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - if is_global_attn: - masked_global_attn_index = tf.tile( - is_index_global_attn[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ) - else: - masked_global_attn_index = tf.tile( - is_index_global_attn[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ) - attn_probs = tf.where( - masked_global_attn_index, - tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), - attn_probs, - ) - - outputs = (attn_output, attn_probs, global_attn_probs) - - return outputs - - def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): - """ - Matrix multiplication of query and key tensors using with a sliding window attention pattern. This - implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an - overlap of size window_overlap - """ - batch_size, seq_len, num_heads, head_dim = shape_list(query) - - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", - ) - tf.debugging.assert_equal( - shape_list(query), - shape_list(key), - message=( - f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:" - f" {shape_list(key)}" - ), - ) - - chunks_count = seq_len // window_overlap - 1 - - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape( - tf.transpose(query, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) - key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) - chunked_query = self._chunk(query, window_overlap) - chunked_key = self._chunk(key, window_overlap) - - # matrix multiplication - # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap - chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) - chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply - - # convert diagonals into columns - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) - diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) - - # allocate space for the overall attention matrix where the chunks are combined. The last dimension - # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to - # window_overlap previous words). The following column is attention score from each word to itself, then - # followed by window_overlap columns for the upper triangle. - - # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions - # - copying the main diagonal and the upper triangle - # TODO: This code is most likely not very efficient and should be improved - diagonal_attn_scores_up_triang = tf.concat( - [ - diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], - diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], - ], - axis=1, - ) - - # - copying the lower triangle - diagonal_attn_scores_low_triang = tf.concat( - [ - tf.zeros( - (batch_size * num_heads, 1, window_overlap, window_overlap), - dtype=diagonal_chunked_attention_scores.dtype, - ), - diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], - ], - axis=1, - ) - diagonal_attn_scores_first_chunk = tf.concat( - [ - tf.roll( - diagonal_chunked_attention_scores, - shift=[1, window_overlap], - axis=[2, 3], - )[:, :, :window_overlap, :window_overlap], - tf.zeros( - (batch_size * num_heads, 1, window_overlap, window_overlap), - dtype=diagonal_chunked_attention_scores.dtype, - ), - ], - axis=1, - ) - first_chunk_mask = ( - tf.tile( - tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None], - (batch_size * num_heads, 1, window_overlap, window_overlap), - ) - < 1 - ) - diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, - diagonal_attn_scores_first_chunk, - diagonal_attn_scores_low_triang, - ) - - # merging upper and lower triangle - diagonal_attention_scores = tf.concat( - [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1 - ) - - # separate batch_size and num_heads dimensions again - diagonal_attention_scores = tf.transpose( - tf.reshape( - diagonal_attention_scores, - (batch_size, num_heads, seq_len, 2 * window_overlap + 1), - ), - (0, 2, 1, 3), - ) - - diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) - - return diagonal_attention_scores - - @staticmethod - def _mask_invalid_locations(input_tensor, window_overlap): - # create correct upper triangle bool mask - mask_2d_upper = tf.reverse( - tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), - axis=[0], - ) - - # pad to full matrix - padding = tf.convert_to_tensor( - [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] - ) - - # create lower mask - mask_2d = tf.pad(mask_2d_upper, padding) - - # combine with upper mask - mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) - - # broadcast to full matrix - mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) - - # inf tensor used for masking - inf_tensor = -float("inf") * tf.ones_like(input_tensor) - - # mask - input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) - - return input_tensor - - def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): - """ - Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the - same shape as `attn_probs` - """ - - batch_size, seq_len, num_heads, head_dim = shape_list(value) - - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[:3], - shape_list(value)[:3], - message="value and attn_probs must have same dims (except head_dim)", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[3], - 2 * window_overlap + 1, - message="attn_probs last dim has to be 2 * window_overlap + 1", - ) - - chunks_count = seq_len // window_overlap - 1 - - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap - chunked_attn_probs = tf.reshape( - tf.transpose(attn_probs, (0, 2, 1, 3)), - ( - batch_size * num_heads, - seq_len // window_overlap, - window_overlap, - 2 * window_overlap + 1, - ), - ) - - # group batch_size and num_heads dimensions into one - value = tf.reshape( - tf.transpose(value, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) - - # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) - padded_value = tf.pad(value, paddings, constant_values=-1) - - # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap - frame_size = 3 * window_overlap * head_dim - frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count - chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), - frame_size, - frame_hop_size, - ) - chunked_value = tf.reshape( - chunked_value, - (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), - ) - - tf.debugging.assert_equal( - shape_list(chunked_value), - [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], - message="Chunked value has the wrong shape", - ) - - chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) - context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose( - tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), - (0, 2, 1, 3), - ) - - return context - - @staticmethod - def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): - """pads rows and then flips rows and columns""" - hidden_states_padded = tf.pad( - hidden_states_padded, paddings - ) # padding value is not important because it will be overwritten - batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) - - return hidden_states_padded - - @staticmethod - def _pad_and_diagonalize(chunked_hidden_states): - """ - shift every row 1 step right, converting columns into diagonals. - - Example: - - ```python - chunked_hidden_states: [ - 0.4983, - 2.6918, - -0.0071, - 1.0492, - -1.8348, - 0.7672, - 0.2986, - 0.0285, - -0.7584, - 0.4206, - -0.0405, - 0.1599, - 2.0514, - -1.1600, - 0.5372, - 0.2629, - ] - window_overlap = num_rows = 4 - ``` - - (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 - 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206, - -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] - """ - total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) - chunked_hidden_states = tf.pad( - chunked_hidden_states, paddings - ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten - chunked_hidden_states = tf.reshape( - chunked_hidden_states, (total_num_heads, num_chunks, -1) - ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap - chunked_hidden_states = chunked_hidden_states[ - :, :, :-window_overlap - ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap - chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), - ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap - chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] - - return chunked_hidden_states - - @staticmethod - def _chunk(hidden_states, window_overlap): - """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" - batch_size, seq_length, hidden_dim = shape_list(hidden_states) - num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1 - - # define frame size and frame stride (similar to convolution) - frame_hop_size = window_overlap * hidden_dim - frame_size = 2 * frame_hop_size - hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) - - # chunk with overlap - chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) - - tf.debugging.assert_equal( - shape_list(chunked_hidden_states), - [batch_size, num_output_chunks, frame_size], - message=( - "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension" - f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}." - ), - ) - - chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), - ) - - return chunked_hidden_states - - @staticmethod - def _get_global_attn_indices(is_index_global_attn): - """compute global attn indices required throughout forward pass""" - # helper variable - num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) - num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) - - # max number of global attn indices in batch - max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) - - # indices of global attn - is_index_global_attn_nonzero = tf.where(is_index_global_attn) - - # helper variable - is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( - num_global_attn_indices, axis=-1 - ) - - # location of the non-padding values within global attention indices - is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) - - # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) - - return ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) - - def _concat_with_global_key_attn_probs( - self, - attn_scores, - key_vectors, - query_vectors, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ): - batch_size = shape_list(key_vectors)[0] - - # select global key vectors - global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) - - # create only global key vectors - key_vectors_only_global = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_key_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), - ) - - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) - - # (batch_size, max_num_global_attn_indices, seq_len, num_heads) - attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) - mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( - shape_list(attn_probs_from_global_key_trans)[-2:] - ) - mask = tf.ones(mask_shape) * -10000.0 - mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) - - # scatter mask - attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, - is_local_index_no_global_attn_nonzero, - mask, - ) - - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) - - # concat to attn_probs - # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1) - - return attn_scores - - def _compute_attn_output_with_global_indices( - self, - value_vectors, - attn_probs, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - ): - batch_size = shape_list(attn_probs)[0] - - # cut local attn probs to global only - attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices] - - # select global value vectors - global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero) - - # create only global value vectors - value_vectors_only_global = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_value_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), - ) - - # compute attn output only global - attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) - - # reshape attn probs - attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] - - # compute attn output with global - attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( - attn_probs_without_global, value_vectors, self.one_sided_attn_window_size - ) - - return attn_output_only_global + attn_output_without_global - - def _compute_global_attn_output_from_hidden( - self, - attn_output, - hidden_states, - max_num_global_attn_indices, - layer_head_mask, - is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - is_index_masked, - training, - ): - batch_size, seq_len = shape_list(hidden_states)[:2] - - # prepare global hidden states - global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) - global_attn_hidden_states = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_attn_hidden_states, - shape=(batch_size, max_num_global_attn_indices, self.embed_dim), - ) - - # global key, query, value - global_query_vectors_only_global = self.query_global(global_attn_hidden_states) - global_key_vectors = self.key_global(hidden_states) - global_value_vectors = self.value_global(hidden_states) - - # normalize - global_query_vectors_only_global /= tf.math.sqrt( - tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) - ) - global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) - global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) - global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) - - # compute attn scores - global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(global_attn_scores), - [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], - message=( - "global_attn_scores have the wrong size. Size should be" - f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is" - f" {shape_list(global_attn_scores)}." - ), - ) - - global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), - ) - global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) - mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( - shape_list(global_attn_scores_trans)[-2:] - ) - global_attn_mask = tf.ones(mask_shape) * -10000.0 - global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) - - # scatter mask - global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, - is_local_index_no_global_attn_nonzero, - global_attn_mask, - ) - global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) - - # mask global attn scores - attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) - global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) - global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), - ) - - # compute global attn probs - global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1) - - # apply layer head masking - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - ) - global_attn_probs_float = tf.reshape( - global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) - ) - - # dropout - global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) - - # global attn output - global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) - - tf.debugging.assert_equal( - shape_list(global_attn_output), - [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], - message=( - "global_attn_output tensor has the wrong size. Size should be" - f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" - f" {shape_list(global_attn_output)}." - ), - ) - - global_attn_output = tf.reshape( - global_attn_output, - (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), - ) - - # get only non zero global attn output - nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), - is_local_index_global_attn_nonzero, - ) - nonzero_global_attn_output = tf.reshape( - nonzero_global_attn_output, - (shape_list(is_local_index_global_attn_nonzero)[0], -1), - ) - - # overwrite values with global attention - attn_output = tf.tensor_scatter_nd_update( - attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output - ) - - global_attn_probs = tf.reshape( - global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - ) - - return attn_output, global_attn_probs - - def reshape_and_transpose(self, vector, batch_size): - return tf.reshape( - tf.transpose( - tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), - (0, 2, 1, 3), - ), - (batch_size * self.num_heads, -1, self.head_dim), - ) - - -class TFLEDEncoderAttention(keras.layers.Layer): - def __init__(self, config, layer_id, **kwargs): - super().__init__(**kwargs) - self.longformer_self_attn = TFLEDEncoderSelfAttention(config, layer_id=layer_id, name="longformer_self_attn") - self.output_dense = keras.layers.Dense(config.d_model, use_bias=True, name="output") - self.config = config - - def call(self, inputs, training=False): - ( - hidden_states, - attention_mask, - layer_head_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - ) = inputs - - self_outputs = self.longformer_self_attn( - [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], - training=training, - ) - - attention_output = self.output_dense(self_outputs[0], training=training) - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer_self_attn", None) is not None: - with tf.name_scope(self.longformer_self_attn.name): - self.longformer_self_attn.build(None) - if getattr(self, "output_dense", None) is not None: - with tf.name_scope(self.output_dense.name): - self.output_dense.build([None, None, self.config.d_model]) - - -class TFLEDDecoderAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training=False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast( - attention_mask, dtype=attn_weights.dtype - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFLEDEncoderLayer(keras.layers.Layer): - def __init__(self, config: LEDConfig, layer_id: int, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFLEDEncoderAttention(config, layer_id, name="self_attn") - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - layer_head_mask: tf.Tensor, - is_index_masked: tf.Tensor, - is_index_global_attn: tf.Tensor, - is_global_attn: bool, - training=False, - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(config.encoder_attention_heads,)*. - """ - residual = hidden_states - layer_outputs = self.self_attn( - [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], - training=training, - ) - - hidden_states = layer_outputs[0] - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return (hidden_states,) + layer_outputs[1:] - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFLEDDecoderLayer(keras.layers.Layer): - def __init__(self, config: LEDConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFLEDDecoderAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFLEDDecoderAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - encoder_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training=False, - ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape *(batch, seq_len, embed_dim)* - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(config.encoder_attention_heads,)*. - encoder_layer_head_mask (`tf.Tensor`): mask for encoder attention heads in a given layer of - size *(config.encoder_attention_heads,)*. - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - - # Self-Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFLEDPreTrainedModel(TFPreTrainedModel): - config_class = LEDConfig - base_model_prefix = "led" - - @property - def input_signature(self): - sig = super().input_signature - sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask") - return sig - - -@dataclass -# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder -class TFLEDEncoderBaseModelOutput(ModelOutput): - """ - Base class for Longformer's outputs, with potential hidden states, local and global attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLEDSeq2SeqModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - last_hidden_state: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor, ...] | None = None - decoder_attentions: tuple[tf.Tensor, ...] | None = None - cross_attentions: tuple[tf.Tensor, ...] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor, ...] | None = None - encoder_attentions: tuple[tf.Tensor, ...] | None = None - encoder_global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLEDSeq2SeqLMOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - decoder_hidden_states: tuple[tf.Tensor, ...] | None = None - decoder_attentions: tuple[tf.Tensor, ...] | None = None - cross_attentions: tuple[tf.Tensor, ...] | None = None - encoder_last_hidden_state: tf.Tensor | None = None - encoder_hidden_states: tuple[tf.Tensor, ...] | None = None - encoder_attentions: tuple[tf.Tensor, ...] | None = None - encoder_global_attentions: tuple[tf.Tensor, ...] | None = None - - -LED_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`LEDConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -LED_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.Tensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFLEDEncoder(keras.layers.Layer): - config_class = LEDConfig - """ - Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a - [`TFLEDEncoderLayer`]. - - Args: - config: LEDConfig - """ - - def __init__(self, config: LEDConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - if config.encoder_layerdrop > 0: - logger.warning("Layerdrop is currently disabled in TFLED models.") - self.layerdrop = 0.0 - self.padding_idx = config.pad_token_id - - if isinstance(config.attention_window, int): - assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" - assert config.attention_window > 0, "`config.attention_window` has to be positive" - config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer - else: - assert len(config.attention_window) == config.num_hidden_layers, ( - "`len(config.attention_window)` should equal `config.num_hidden_layers`. " - f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" - ) - - self.attention_window = config.attention_window - self.embed_tokens = embed_tokens - self.embed_positions = TFLEDLearnedPositionalEmbedding( - config.max_encoder_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.embed_dim = config.d_model - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - global_attention_mask=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - # merge `global_attention_mask` and `attention_mask` - if global_attention_mask is not None: - attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype) - - padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - pad_token_id=self.padding_idx, - ) - - input_shape = shape_list(attention_mask) - # is index masked or global attention - is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1) - is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1) - is_global_attn = tf.math.reduce_any(is_index_global_attn) - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask)[:, 0, 0, :] - attention_mask = attention_mask[:, :, None, None] - - encoder_states = () if output_hidden_states else None - all_attentions = all_global_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len) - encoder_states = encoder_states + (hidden_states_to_add,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - layer_outputs = encoder_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) - all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) - - # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn - all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),) - - # undo padding - # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) - hidden_states = self.compute_hidden_states(hidden_states, padding_len) - - # undo padding - if output_attentions: - all_attentions = ( - tuple(state[:, :, :-padding_len, :] for state in all_attentions) if padding_len > 0 else all_attentions - ) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFLEDEncoderBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - global_attentions=all_global_attentions, - ) - - @tf.function - def compute_hidden_states(self, hidden_states, padding_len): - return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - - def _pad_to_window_size( - self, - input_ids, - attention_mask, - inputs_embeds, - pad_token_id, - ): - """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" - # padding - attention_window = ( - self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) - ) - - assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" - - input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) - batch_size, seq_len = input_shape[:2] - padding_len = (attention_window - seq_len % attention_window) % attention_window - - if padding_len > 0: - logger.warning_once( - f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " - f"`config.attention_window`: {attention_window}" - ) - - paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) - - if input_ids is not None: - input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) - - if inputs_embeds is not None: - if padding_len > 0: - input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id) - inputs_embeds_padding = self.embed_tokens(input_ids_padding) - inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - - attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens - - return ( - padding_len, - input_ids, - attention_mask, - inputs_embeds, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.embed_dim]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFLEDDecoder(keras.layers.Layer): - config_class = LEDConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFLEDDecoderLayer`] - - Args: - config: LEDConfig - embed_tokens: output embedding - """ - - def __init__(self, config: LEDConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - if config.decoder_layerdrop > 0: - logger.warning("Layerdrop is currently disabled in TFLED models.") - self.layerdrop = 0.0 - self.embed_positions = TFLEDLearnedPositionalEmbedding( - config.max_decoder_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - - self.dropout = keras.layers.Dropout(config.dropout) - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - encoder_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. If `past_key_values` are used, the user can optionally input only the last - `decoder_input_ids` (those that don't have their past key value states given to this model) of shape - `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None and input_shape[-1] > 1: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () - all_self_attns = () - all_cross_attentions = () - present_key_values = () - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - all_cross_attentions += (layer_cross_attn,) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - - all_self_attns = all_self_attns if output_attentions else None - all_cross_attentions = all_cross_attentions if output_attentions else None - - present_key_values = present_key_values if use_cache else None - - if not return_dict: - return tuple( - v - for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFLEDMainLayer(keras.layers.Layer): - config_class = LEDConfig - - def __init__(self, config: LEDConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="led.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "led.shared" - - self.encoder = TFLEDEncoder(config, self.shared, name="encoder") - self.decoder = TFLEDDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - encoder_outputs: tuple | TFLEDEncoderBaseModelOutput | None = None, - global_attention_mask=None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - **kwargs, - ): - if decoder_input_ids is None and decoder_inputs_embeds is None: - use_cache = False - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput): - encoder_outputs = TFLEDEncoderBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - encoder_head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFLEDSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - encoder_global_attentions=encoder_outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare LED Model outputting raw hidden-states without any specific head on top.", - LED_START_DOCSTRING, -) -class TFLEDModel(TFLEDPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.led = TFLEDMainLayer(config, name="led") - - def get_encoder(self): - return self.led.encoder - - def get_decoder(self): - return self.led.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFLEDSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - encoder_outputs: tf.Tensor | None = None, - global_attention_mask: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> tuple[tf.Tensor] | TFLEDSeq2SeqModelOutput: - outputs = self.led( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - global_attention_mask=global_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None - - return TFLEDSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - encoder_global_attentions=enc_g_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "led", None) is not None: - with tf.name_scope(self.led.name): - self.led.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The LED Model with a language modeling head. Can be used for summarization.", - LED_START_DOCSTRING, -) -class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [ - r"led.encoder.embed_tokens.weight", - r"led.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.led = TFLEDMainLayer(config, name="led") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - # TODO (Joao): investigate why LED has numerical issues in XLA generate - self.supports_xla_generation = False - - def get_decoder(self): - return self.led.decoder - - def get_encoder(self): - return self.led.encoder - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - @unpack_inputs - @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: TFLEDEncoderBaseModelOutput | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFLEDSeq2SeqLMOutput: - """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFLEDForConditionalGeneration - >>> import tensorflow as tf - - >>> mname = "allenai/led-base-16384" - >>> tokenizer = AutoTokenizer.from_pretrained(mname) - >>> TXT = "My friends are but they eat too many carbs." - >>> model = TFLEDForConditionalGeneration.from_pretrained(mname) - >>> batch = tokenizer([TXT], return_tensors="tf") - >>> logits = model(inputs=batch.input_ids).logits - >>> probs = tf.nn.softmax(logits[0]) - >>> # probs[5] is associated with the mask token - ```""" - - if labels is not None: - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.led( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - global_attention_mask=global_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.led.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFLEDSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - encoder_global_attentions=outputs.encoder_global_attentions, - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None - - return TFLEDSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - encoder_global_attentions=enc_g_attns, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def hf_compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - if self.config.tf_legacy_loss: - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_loss = loss_fn(tf.nn.relu(labels), logits) - # make sure only non-padding labels affect the loss - loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype) - masked_loss = unmasked_loss * loss_mask - reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) - return tf.reshape(reduced_masked_loss, (1,)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "led", None) is not None: - with tf.name_scope(self.led.name): - self.led.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"] diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py deleted file mode 100644 index 63e34e996ade..000000000000 --- a/src/transformers/models/llama/modeling_flax_llama.py +++ /dev/null @@ -1,747 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Meta AI, EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax LLaMA model.""" - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_llama import LlamaConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" -_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny" -_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2" - -LLAMA_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or - `jax.numpy.bfloat16`. - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def create_sinusoidal_positions(num_pos, dim): - inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) - freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") - - emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) - return jnp.array(out[:, :, :num_pos]) - - -def rotate_half(tensor): - """Rotates half the hidden dims of the input.""" - rotate_half_tensor = jnp.concatenate( - (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 - ) - return rotate_half_tensor - - -def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): - return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) - - -class FlaxLlamaRMSNorm(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.epsilon = self.config.rms_norm_eps - self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) - - def __call__(self, hidden_states): - variance = jnp.asarray(hidden_states, dtype=jnp.float32) - variance = jnp.power(variance, 2) - variance = variance.mean(-1, keepdims=True) - # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` - hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) - - return self.weight * jnp.asarray(hidden_states, dtype=self.dtype) - - -class FlaxLlamaRotaryEmbedding(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - head_dim = self.config.hidden_size // self.config.num_attention_heads - self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) - - def __call__(self, key, query, position_ids): - sincos = self.sincos[position_ids] - sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) - - key = apply_rotary_pos_emb(key, sin_pos, cos_pos) - query = apply_rotary_pos_emb(query, sin_pos, cos_pos) - - key = jnp.asarray(key, dtype=self.dtype) - query = jnp.asarray(query, dtype=self.dtype) - - return key, query - - -class FlaxLlamaAttention(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - causal: bool = True - is_cross_attention: bool = False - - def setup(self): - config = self.config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 - - dense = partial( - nn.Dense, - use_bias=config.attention_bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.q_proj = dense(self.num_heads * self.head_dim) - self.k_proj = dense(self.num_key_value_heads * self.head_dim) - self.v_proj = dense(self.num_key_value_heads * self.head_dim) - self.o_proj = dense(self.embed_dim) - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype) - - def _split_heads(self, hidden_states, num_heads): - return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - position_ids, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, self.num_heads) - key = self._split_heads(key, self.num_key_value_heads) - value = self._split_heads(value, self.num_key_value_heads) - - key, query = self.rotary_emb(key, query, position_ids) - - query_length, key_length = query.shape[1], key.shape[1] - - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - - dropout_rng = None - if not deterministic and self.config.attention_dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.has_variable("cache", "cached_key") or init_cache: - key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) - - key = jnp.repeat(key, self.num_key_value_groups, axis=2) - value = jnp.repeat(value, self.num_key_value_groups, axis=2) - - # transform boolean mask into float mask - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - - # usual dot product attention - attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype - attn_weights = dot_product_attention_weights( - query, - key, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_dropout, - deterministic=deterministic, - dtype=attention_dtype, - ) - - if self.attention_softmax_in_fp32: - attn_weights = attn_weights.astype(self.dtype) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) - attn_output = self._merge_heads(attn_output) - attn_output = self.o_proj(attn_output) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxLlamaMLP(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim - - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - self.act = ACT2FN[self.config.hidden_act] - - self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - - def __call__(self, hidden_states): - up_proj_states = self.up_proj(hidden_states) - gate_states = self.act(self.gate_proj(hidden_states)) - - hidden_states = self.down_proj(up_proj_states * gate_states) - return hidden_states - - -class FlaxLlamaDecoderLayer(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.input_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) - self.self_attn = FlaxLlamaAttention(self.config, dtype=self.dtype) - self.post_attention_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) - self.mlp = FlaxLlamaMLP(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - outputs = self.self_attn( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - # residual connection - attn_output = outputs[0] - hidden_states = residual + attn_output - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + hidden_states - - return (hidden_states,) + outputs[1:] - - -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama, GPT_NEO->LLAMA, transformer->model -class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LlamaConfig - base_model_prefix = "model" - module_class: nn.Module = None - - def __init__( - self, - config: LlamaConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxLlamaAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxLlamaLayerCollection(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxLlamaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxLlamaModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -class FlaxLlamaModule(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.hidden_size = self.config.hidden_size - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.hidden_size, - embedding_init=embedding_init, - dtype=self.dtype, - ) - self.layers = FlaxLlamaLayerCollection(self.config, dtype=self.dtype) - self.norm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.embed_tokens(input_ids.astype("i4")) - - outputs = self.layers( - input_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare Llama Model transformer outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class FlaxLlamaModel(FlaxLlamaPreTrainedModel): - module_class = FlaxLlamaModule - - -append_call_sample_docstring( - FlaxLlamaModel, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutput, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - - -class FlaxLlamaForCausalLMModule(nn.Module): - config: LlamaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.model = FlaxLlamaModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.model( - input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The Llama Model transformer with a language modeling head (linear layer) on top. - """, - LLAMA_START_DOCSTRING, -) -# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Llama -class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel): - module_class = FlaxLlamaForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Llama uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxLlamaForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutput, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - - -__all__ = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"] diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py deleted file mode 100644 index 891f5d76c95c..000000000000 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ /dev/null @@ -1,2783 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tensorflow Longformer model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_longformer import LongformerConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096" -_CONFIG_FOR_DOC = "LongformerConfig" - -LARGE_NEGATIVE = -1e8 - - -@dataclass -class TFLongformerBaseModelOutput(ModelOutput): - """ - Base class for Longformer's outputs, with potential hidden states, local and global attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerBaseModelOutputWithPooling(ModelOutput): - """ - Base class for Longformer's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) further processed by a - Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence - prediction (classification) objective during pretraining. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerMaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering Longformer models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - start_logits: tf.Tensor | None = None - end_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerMultipleChoiceModelOutput(ModelOutput): - """ - Base class for outputs of multiple choice models. - - Args: - loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided): - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFLongformerTokenClassifierOutput(ModelOutput): - """ - Base class for outputs of token classification models. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x + - attention_window + 1)`, where `x` is the number of tokens with global attention mask. - - Local attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token in the sequence to every token with - global attention (first `x` values) and to every token in the attention window (remaining `attention_window - + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the - remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a - token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding - (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens. - If the attention window contains a token with global attention, the attention weight at the corresponding - index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global - attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be - accessed from `global_attentions`. - global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x` - is the number of tokens with global attention mask. - - Global attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. Those are the attention weights from every token with global attention to every token - in the sequence. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - global_attentions: tuple[tf.Tensor, ...] | None = None - - -def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True): - """ - Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is - True` else after `sep_token_id`. - """ - assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions" - question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None] - # bool attention mask with True in locations of global attention - attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0) - attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1)) - if before_sep_token is True: - question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1])) - attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype) - else: - # last token is separation token and should not be counted and in the middle are two separation tokens - question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1])) - attention_mask = tf.cast( - attention_mask > question_end_index, - dtype=question_end_index.dtype, - ) * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype) - - return attention_mask - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer -class TFLongformerLMHead(keras.layers.Layer): - """Longformer Head for masked language modeling.""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -class TFLongformerEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - - return incremental_indices + self.padding_idx - - def call( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - inputs_embeds=None, - past_key_values_length=0, - training=False, - ): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64) - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids( - input_ids=input_ids, past_key_values_length=past_key_values_length - ) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64), - axis=0, - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer -class TFLongformerIntermediate(keras.layers.Layer): - def __init__(self, config: LongformerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer -class TFLongformerOutput(keras.layers.Layer): - def __init__(self, config: LongformerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer -class TFLongformerPooler(keras.layers.Layer): - def __init__(self, config: LongformerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer -class TFLongformerSelfOutput(keras.layers.Layer): - def __init__(self, config: LongformerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFLongformerSelfAttention(keras.layers.Layer): - def __init__(self, config, layer_id, **kwargs): - super().__init__(**kwargs) - self.config = config - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - - self.num_heads = config.num_attention_heads - self.head_dim = int(config.hidden_size / config.num_attention_heads) - self.embed_dim = config.hidden_size - self.query = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query", - ) - self.key = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key", - ) - self.value = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value", - ) - - # separate projection layers for tokens with global attention - self.query_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query_global", - ) - self.key_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key_global", - ) - self.value_global = keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value_global", - ) - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.layer_id = layer_id - attention_window = config.attention_window[self.layer_id] - - assert attention_window % 2 == 0, ( - f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}" - ) - assert attention_window > 0, ( - f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" - ) - - self.one_sided_attn_window_size = attention_window // 2 - - def build(self, input_shape=None): - if not self.built: - with tf.name_scope("query_global"): - self.query_global.build((self.config.hidden_size,)) - with tf.name_scope("key_global"): - self.key_global.build((self.config.hidden_size,)) - with tf.name_scope("value_global"): - self.value_global.build((self.config.hidden_size,)) - - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - if getattr(self, "query_global", None) is not None: - with tf.name_scope(self.query_global.name): - self.query_global.build([None, None, self.config.hidden_size]) - if getattr(self, "key_global", None) is not None: - with tf.name_scope(self.key_global.name): - self.key_global.build([None, None, self.config.hidden_size]) - if getattr(self, "value_global", None) is not None: - with tf.name_scope(self.value_global.name): - self.value_global.build([None, None, self.config.hidden_size]) - - def call( - self, - inputs, - training=False, - ): - """ - LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to - *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer. - - The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to: - - - -10000: no attention - - 0: local attention - - +10000: global attention - """ - # retrieve input args - ( - hidden_states, - attention_mask, - layer_head_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - ) = inputs - - # project hidden states - query_vectors = self.query(hidden_states) - key_vectors = self.key(hidden_states) - value_vectors = self.value(hidden_states) - batch_size, seq_len, embed_dim = shape_list(hidden_states) - - tf.debugging.assert_equal( - embed_dim, - self.embed_dim, - message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", - ) - - # normalize query - query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) - query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - - # attn_probs = (batch_size, seq_len, num_heads, window*2+1) - attn_scores = self._sliding_chunks_query_key_matmul( - query_vectors, key_vectors, self.one_sided_attn_window_size - ) - - # values to pad for attention probs - remove_from_windowed_attention_mask = attention_mask != 0 - # cast to fp32/fp16 then replace 1's with -inf - float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE - - # diagonal mask with zeros everywhere and -inf inplace of padding - diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask)), - float_mask, - self.one_sided_attn_window_size, - ) - - # pad local attention probs - attn_scores += diagonal_mask - - tf.debugging.assert_equal( - shape_list(attn_scores), - [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], - message=( - f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}," - f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}" - ), - ) - - # compute global attn indices required through out forward fn - ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) = self._get_global_attn_indices(is_index_global_attn) - - # this function is only relevant for global attention - if is_global_attn: - attn_scores = self._concat_with_global_key_attn_probs( - attn_scores=attn_scores, - query_vectors=query_vectors, - key_vectors=key_vectors, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ) - - attn_probs = stable_softmax(attn_scores, axis=-1) - - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - # Make sure to create a mask with the proper shape: - # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] - # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - if is_global_attn: - masked_index = tf.tile( - is_index_masked[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ) - else: - masked_index = tf.tile( - is_index_masked[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ) - attn_probs = tf.where( - masked_index, - tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), - attn_probs, - ) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs - - # apply dropout - attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - - # if global attention, compute sum of global and local attn - - if is_global_attn: - attn_output = self._compute_attn_output_with_global_indices( - value_vectors=value_vectors, - attn_probs=attn_probs, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ) - else: - attn_output = self._sliding_chunks_matmul_attn_probs_value( - attn_probs, value_vectors, self.one_sided_attn_window_size - ) - - tf.debugging.assert_equal( - shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" - ) - - attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) - - # compute value for global attention and overwrite to attention output - if is_global_attn: - attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( - attn_output=attn_output, - hidden_states=hidden_states, - max_num_global_attn_indices=max_num_global_attn_indices, - layer_head_mask=layer_head_mask, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - is_index_masked=is_index_masked, - training=training, - ) - else: - # Leave attn_output unchanged - global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) - - # make sure that local attention probabilities are set to 0 for indices of global attn - # Make sure to create a mask with the proper shape: - # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] - # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - if is_global_attn: - masked_global_attn_index = tf.tile( - is_index_global_attn[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ) - else: - masked_global_attn_index = tf.tile( - is_index_global_attn[:, :, None, None], - (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ) - attn_probs = tf.where( - masked_global_attn_index, - tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), - attn_probs, - ) - - outputs = (attn_output, attn_probs, global_attn_probs) - - return outputs - - def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): - """ - Matrix multiplication of query and key tensors using with a sliding window attention pattern. This - implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an - overlap of size window_overlap - """ - batch_size, seq_len, num_heads, head_dim = shape_list(query) - - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", - ) - tf.debugging.assert_equal( - shape_list(query), - shape_list(key), - message=( - f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:" - f" {shape_list(key)}" - ), - ) - - chunks_count = seq_len // window_overlap - 1 - - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape( - tf.transpose(query, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) - key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) - chunked_query = self._chunk(query, window_overlap) - chunked_key = self._chunk(key, window_overlap) - - # matrix multiplication - # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap - chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) - chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply - - # convert diagonals into columns - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) - diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) - - # allocate space for the overall attention matrix where the chunks are combined. The last dimension - # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to - # window_overlap previous words). The following column is attention score from each word to itself, then - # followed by window_overlap columns for the upper triangle. - - # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions - # - copying the main diagonal and the upper triangle - # TODO: This code is most likely not very efficient and should be improved - diagonal_attn_scores_up_triang = tf.concat( - [ - diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], - diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], - ], - axis=1, - ) - - # - copying the lower triangle - diagonal_attn_scores_low_triang = tf.concat( - [ - tf.zeros( - (batch_size * num_heads, 1, window_overlap, window_overlap), - dtype=diagonal_chunked_attention_scores.dtype, - ), - diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], - ], - axis=1, - ) - diagonal_attn_scores_first_chunk = tf.concat( - [ - tf.roll( - diagonal_chunked_attention_scores, - shift=[1, window_overlap], - axis=[2, 3], - )[:, :, :window_overlap, :window_overlap], - tf.zeros( - (batch_size * num_heads, 1, window_overlap, window_overlap), - dtype=diagonal_chunked_attention_scores.dtype, - ), - ], - axis=1, - ) - first_chunk_mask = ( - tf.tile( - tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None], - (batch_size * num_heads, 1, window_overlap, window_overlap), - ) - < 1 - ) - diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, - diagonal_attn_scores_first_chunk, - diagonal_attn_scores_low_triang, - ) - - # merging upper and lower triangle - diagonal_attention_scores = tf.concat( - [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1 - ) - - # separate batch_size and num_heads dimensions again - diagonal_attention_scores = tf.transpose( - tf.reshape( - diagonal_attention_scores, - (batch_size, num_heads, seq_len, 2 * window_overlap + 1), - ), - (0, 2, 1, 3), - ) - - diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) - - return diagonal_attention_scores - - @staticmethod - def _mask_invalid_locations(input_tensor, window_overlap): - # create correct upper triangle bool mask - mask_2d_upper = tf.reverse( - tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), - axis=[0], - ) - - # pad to full matrix - padding = tf.convert_to_tensor( - [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] - ) - - # create lower mask - mask_2d = tf.pad(mask_2d_upper, padding) - - # combine with upper mask - mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) - - # broadcast to full matrix - mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) - - # inf tensor used for masking - inf_tensor = -float("inf") * tf.ones_like(input_tensor) - - # mask - input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) - - return input_tensor - - def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): - """ - Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the - same shape as `attn_probs` - """ - - batch_size, seq_len, num_heads, head_dim = shape_list(value) - - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[:3], - shape_list(value)[:3], - message="value and attn_probs must have same dims (except head_dim)", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[3], - 2 * window_overlap + 1, - message="attn_probs last dim has to be 2 * window_overlap + 1", - ) - - chunks_count = seq_len // window_overlap - 1 - - # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap - chunked_attn_probs = tf.reshape( - tf.transpose(attn_probs, (0, 2, 1, 3)), - ( - batch_size * num_heads, - seq_len // window_overlap, - window_overlap, - 2 * window_overlap + 1, - ), - ) - - # group batch_size and num_heads dimensions into one - value = tf.reshape( - tf.transpose(value, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) - - # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) - padded_value = tf.pad(value, paddings, constant_values=-1) - - # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap - frame_size = 3 * window_overlap * head_dim - frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count - chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), - frame_size, - frame_hop_size, - ) - chunked_value = tf.reshape( - chunked_value, - (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), - ) - - tf.debugging.assert_equal( - shape_list(chunked_value), - [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], - message="Chunked value has the wrong shape", - ) - - chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) - context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose( - tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), - (0, 2, 1, 3), - ) - - return context - - @staticmethod - def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): - """pads rows and then flips rows and columns""" - hidden_states_padded = tf.pad( - hidden_states_padded, paddings - ) # padding value is not important because it will be overwritten - batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) - - return hidden_states_padded - - @staticmethod - def _pad_and_diagonalize(chunked_hidden_states): - """ - shift every row 1 step right, converting columns into diagonals. - - Example: - - ```python - chunked_hidden_states: [ - 0.4983, - 2.6918, - -0.0071, - 1.0492, - -1.8348, - 0.7672, - 0.2986, - 0.0285, - -0.7584, - 0.4206, - -0.0405, - 0.1599, - 2.0514, - -1.1600, - 0.5372, - 0.2629, - ] - window_overlap = num_rows = 4 - ``` - - (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 - 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206, - -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] - """ - total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) - chunked_hidden_states = tf.pad( - chunked_hidden_states, paddings - ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten - chunked_hidden_states = tf.reshape( - chunked_hidden_states, (total_num_heads, num_chunks, -1) - ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap - chunked_hidden_states = chunked_hidden_states[ - :, :, :-window_overlap - ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap - chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), - ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap - chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] - - return chunked_hidden_states - - @staticmethod - def _chunk(hidden_states, window_overlap): - """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" - batch_size, seq_length, hidden_dim = shape_list(hidden_states) - num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1 - - # define frame size and frame stride (similar to convolution) - frame_hop_size = window_overlap * hidden_dim - frame_size = 2 * frame_hop_size - hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) - - # chunk with overlap - chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) - - tf.debugging.assert_equal( - shape_list(chunked_hidden_states), - [batch_size, num_output_chunks, frame_size], - message=( - "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension" - f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}." - ), - ) - - chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), - ) - - return chunked_hidden_states - - @staticmethod - def _get_global_attn_indices(is_index_global_attn): - """compute global attn indices required throughout forward pass""" - # helper variable - num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) - num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) - - # max number of global attn indices in batch - max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) - - # indices of global attn - is_index_global_attn_nonzero = tf.where(is_index_global_attn) - - # helper variable - is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( - num_global_attn_indices, axis=-1 - ) - - # location of the non-padding values within global attention indices - is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) - - # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) - - return ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) - - def _concat_with_global_key_attn_probs( - self, - attn_scores, - key_vectors, - query_vectors, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ): - batch_size = shape_list(key_vectors)[0] - - # select global key vectors - global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) - - # create only global key vectors - key_vectors_only_global = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_key_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), - ) - - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) - - # (batch_size, max_num_global_attn_indices, seq_len, num_heads) - attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) - mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( - shape_list(attn_probs_from_global_key_trans)[-2:] - ) - mask = tf.ones(mask_shape) * -10000.0 - mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) - - # scatter mask - attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, - is_local_index_no_global_attn_nonzero, - mask, - ) - - # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) - - # concat to attn_probs - # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1) - - return attn_scores - - def _compute_attn_output_with_global_indices( - self, - value_vectors, - attn_probs, - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - ): - batch_size = shape_list(attn_probs)[0] - - # cut local attn probs to global only - attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices] - - # select global value vectors - global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero) - - # create only global value vectors - value_vectors_only_global = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_value_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), - ) - - # compute attn output only global - attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) - - # reshape attn probs - attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] - - # compute attn output with global - attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( - attn_probs_without_global, value_vectors, self.one_sided_attn_window_size - ) - - return attn_output_only_global + attn_output_without_global - - def _compute_global_attn_output_from_hidden( - self, - attn_output, - hidden_states, - max_num_global_attn_indices, - layer_head_mask, - is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - is_index_masked, - training, - ): - batch_size, seq_len = shape_list(hidden_states)[:2] - - # prepare global hidden states - global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) - global_attn_hidden_states = tf.scatter_nd( - is_local_index_global_attn_nonzero, - global_attn_hidden_states, - shape=(batch_size, max_num_global_attn_indices, self.embed_dim), - ) - - # global key, query, value - global_query_vectors_only_global = self.query_global(global_attn_hidden_states) - global_key_vectors = self.key_global(hidden_states) - global_value_vectors = self.value_global(hidden_states) - - # normalize - global_query_vectors_only_global /= tf.math.sqrt( - tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) - ) - global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) - global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) - global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) - - # compute attn scores - global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(global_attn_scores), - [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], - message=( - "global_attn_scores have the wrong size. Size should be" - f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is" - f" {shape_list(global_attn_scores)}." - ), - ) - - global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), - ) - global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) - mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( - shape_list(global_attn_scores_trans)[-2:] - ) - global_attn_mask = tf.ones(mask_shape) * -10000.0 - global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) - - # scatter mask - global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, - is_local_index_no_global_attn_nonzero, - global_attn_mask, - ) - global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) - - # mask global attn scores - attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) - global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) - global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), - ) - - # compute global attn probs - global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1) - - # apply layer head masking - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - ) - global_attn_probs_float = tf.reshape( - global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) - ) - - # dropout - global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) - - # global attn output - global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) - - tf.debugging.assert_equal( - shape_list(global_attn_output), - [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], - message=( - "global_attn_output tensor has the wrong size. Size should be" - f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is" - f" {shape_list(global_attn_output)}." - ), - ) - - global_attn_output = tf.reshape( - global_attn_output, - (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), - ) - - # get only non zero global attn output - nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), - is_local_index_global_attn_nonzero, - ) - nonzero_global_attn_output = tf.reshape( - nonzero_global_attn_output, - (shape_list(is_local_index_global_attn_nonzero)[0], -1), - ) - - # overwrite values with global attention - attn_output = tf.tensor_scatter_nd_update( - attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output - ) - - global_attn_probs = tf.reshape( - global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) - ) - - return attn_output, global_attn_probs - - def reshape_and_transpose(self, vector, batch_size): - return tf.reshape( - tf.transpose( - tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), - (0, 2, 1, 3), - ), - (batch_size * self.num_heads, -1, self.head_dim), - ) - - -class TFLongformerAttention(keras.layers.Layer): - def __init__(self, config, layer_id=0, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self") - self.dense_output = TFLongformerSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, inputs, training=False): - ( - hidden_states, - attention_mask, - layer_head_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - ) = inputs - - self_outputs = self.self_attention( - [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], - training=training, - ) - attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFLongformerLayer(keras.layers.Layer): - def __init__(self, config, layer_id=0, **kwargs): - super().__init__(**kwargs) - - self.attention = TFLongformerAttention(config, layer_id, name="attention") - self.intermediate = TFLongformerIntermediate(config, name="intermediate") - self.longformer_output = TFLongformerOutput(config, name="output") - - def call(self, inputs, training=False): - ( - hidden_states, - attention_mask, - layer_head_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - ) = inputs - - attention_outputs = self.attention( - [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], - training=training, - ) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(attention_output) - layer_output = self.longformer_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "longformer_output", None) is not None: - with tf.name_scope(self.longformer_output.name): - self.longformer_output.build(None) - - -class TFLongformerEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.layer = [TFLongformerLayer(config, i, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states, - attention_mask=None, - head_mask=None, - padding_len=0, - is_index_masked=None, - is_index_global_attn=None, - is_global_attn=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - all_hidden_states = () if output_hidden_states else None - all_attentions = all_global_attentions = () if output_attentions else None - - for idx, layer_module in enumerate(self.layer): - if output_hidden_states: - hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - all_hidden_states = all_hidden_states + (hidden_states_to_add,) - - layer_outputs = layer_module( - [ - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - is_index_masked, - is_index_global_attn, - is_global_attn, - ], - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) - all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) - - # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn - all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),) - - # Add last layer - if output_hidden_states: - hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - all_hidden_states = all_hidden_states + (hidden_states_to_add,) - - # undo padding - # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1) - hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - if output_attentions: - all_attentions = ( - tuple(state[:, :, :-padding_len, :] for state in all_attentions) if padding_len > 0 else all_attentions - ) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None - ) - - return TFLongformerBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - global_attentions=all_global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFLongformerMainLayer(keras.layers.Layer): - config_class = LongformerConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - if isinstance(config.attention_window, int): - assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" - assert config.attention_window > 0, "`config.attention_window` has to be positive" - config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer - else: - assert len(config.attention_window) == config.num_hidden_layers, ( - "`len(config.attention_window)` should equal `config.num_hidden_layers`. " - f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" - ) - - self.config = config - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.pad_token_id = config.pad_token_id - self.attention_window = config.attention_window - self.embeddings = TFLongformerEmbeddings(config, name="embeddings") - self.encoder = TFLongformerEncoder(config, name="encoder") - self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - head_mask=None, - global_attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and not isinstance(input_ids, tf.Tensor): - input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64) - elif input_ids is not None: - input_ids = tf.cast(input_ids, tf.int64) - - if attention_mask is not None and not isinstance(attention_mask, tf.Tensor): - attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64) - elif attention_mask is not None: - attention_mask = tf.cast(attention_mask, tf.int64) - - if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor): - global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64) - elif global_attention_mask is not None: - global_attention_mask = tf.cast(global_attention_mask, tf.int64) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.cast(tf.fill(input_shape, 1), tf.int64) - - if token_type_ids is None: - token_type_ids = tf.cast(tf.fill(input_shape, 0), tf.int64) - - # merge `global_attention_mask` and `attention_mask` - if global_attention_mask is not None: - attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) - - ( - padding_len, - input_ids, - attention_mask, - token_type_ids, - position_ids, - inputs_embeds, - ) = self._pad_to_window_size( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - pad_token_id=self.pad_token_id, - ) - - # is index masked or global attention - is_index_masked = tf.math.less(attention_mask, 1) - is_index_global_attn = tf.math.greater(attention_mask, 1) - is_global_attn = tf.math.reduce_any(is_index_global_attn) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, to_seq_length, 1, 1] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - extended_attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], attention_mask_shape[1], 1, 1)) - - # Since attention_mask is 1.0 for positions we want to attend locally and 0.0 for - # masked and global attn positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 - embedding_output = self.embeddings( - input_ids, - position_ids, - token_type_ids, - inputs_embeds, - training=training, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - padding_len=padding_len, - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFLongformerBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - global_attentions=encoder_outputs.global_attentions, - ) - - def _pad_to_window_size( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - inputs_embeds, - pad_token_id, - ): - """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" - # padding - attention_window = ( - self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) - ) - - assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" - - input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) - batch_size, seq_len = input_shape[:2] - padding_len = (attention_window - seq_len % attention_window) % attention_window - - paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) - - if input_ids is not None: - input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) - - if position_ids is not None: - # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings - position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) - - if inputs_embeds is not None: - if padding_len > 0: - input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64) - inputs_embeds_padding = self.embeddings(input_ids_padding) - inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - - attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens - token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 - - return ( - padding_len, - input_ids, - attention_mask, - token_type_ids, - position_ids, - inputs_embeds, - ) - - @staticmethod - def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): - # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) - # (global_attention_mask + 1) => 1 for local attention, 2 for global attention - # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention - if attention_mask is not None: - attention_mask = attention_mask * (global_attention_mask + 1) - else: - # simply use `global_attention_mask` as `attention_mask` - # if no `attention_mask` is given - attention_mask = global_attention_mask + 1 - - return attention_mask - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFLongformerPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LongformerConfig - base_model_prefix = "longformer" - - @property - def input_signature(self): - sig = super().input_signature - sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask") - return sig - - -LONGFORMER_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`LongformerConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -LONGFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - global_attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to decide the attention given on each token, local attention or global attention. Tokens with global - attention attends to all other tokens, and all other tokens attend to them. This is important for - task-specific finetuning because it makes the model more flexible at representing the task. For example, - for classification, the token should be given global attention. For QA, all question tokens should also - have global attention. Please refer to the [Longformer paper](https://huggingface.co/papers/2004.05150) for more - details. Mask values selected in `[0, 1]`: - - - 0 for local attention (a sliding window attention), - - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Longformer Model outputting raw hidden-states without any specific head on top.", - LONGFORMER_START_DOCSTRING, -) -class TFLongformerModel(TFLongformerPreTrainedModel): - """ - - This class copies code from [`TFRobertaModel`] and overwrites standard self-attention with longformer - self-attention to provide the ability to process long sequences following the self-attention approach described in - [Longformer: the Long-Document Transformer](https://huggingface.co/papers/2004.05150) by Iz Beltagy, Matthew E. Peters, and - Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long - documents without the O(n^2) increase in memory and compute. - - The self-attention module `TFLongformerSelfAttention` implemented here supports the combination of local and global - attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated - attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. Future - release will add support for autoregressive attention, but the support for dilated attention requires a custom CUDA - kernel to be memory and compute efficient. - - """ - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.longformer = TFLongformerMainLayer(config, name="longformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFLongformerBaseModelOutputWithPooling | tuple[tf.Tensor]: - outputs = self.longformer( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - - -@add_start_docstrings( - """Longformer Model with a `language modeling` head on top.""", - LONGFORMER_START_DOCSTRING, -) -class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") - self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="allenai/longformer-base-4096", - output_type=TFLongformerMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - expected_output="' Paris'", - expected_loss=0.44, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFLongformerMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - - outputs = self.longformer( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFLongformerMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - global_attentions=outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -@add_start_docstrings( - """ - Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / - TriviaQA (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - LONGFORMER_START_DOCSTRING, -) -class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") - self.qa_outputs = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="qa_outputs", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="allenai/longformer-large-4096-finetuned-triviaqa", - output_type=TFLongformerQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="' puppet'", - expected_loss=0.96, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFLongformerQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - """ - - if input_ids is not None and not isinstance(input_ids, tf.Tensor): - input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64) - elif input_ids is not None: - input_ids = tf.cast(input_ids, tf.int64) - - if attention_mask is not None and not isinstance(attention_mask, tf.Tensor): - attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64) - elif attention_mask is not None: - attention_mask = tf.cast(attention_mask, tf.int64) - - if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor): - global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64) - elif global_attention_mask is not None: - global_attention_mask = tf.cast(global_attention_mask, tf.int64) - - # set global attention on question tokens - if global_attention_mask is None and input_ids is not None: - if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]: - logger.warning( - f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for" - " questions answering. You might also consider to set `global_attention_mask` manually in the" - " forward function to avoid this. This is most likely an error. The global attention is disabled" - " for this forward pass." - ) - global_attention_mask = tf.cast(tf.fill(shape_list(input_ids), value=0), tf.int64) - else: - logger.warning_once("Initializing global attention on question tokens...") - # put global attention on all tokens until `config.sep_token_id` is reached - sep_token_indices = tf.where(input_ids == self.config.sep_token_id) - sep_token_indices = tf.cast(sep_token_indices, dtype=tf.int64) - global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices) - - outputs = self.longformer( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - - return ((loss,) + output) if loss is not None else output - - return TFLongformerQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - global_attentions=outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -class TFLongformerClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, hidden_states, training=False): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - output = self.out_proj(hidden_states) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - LONGFORMER_START_DOCSTRING, -) -class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") - self.classifier = TFLongformerClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFLongformerSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFLongformerSequenceClassifierOutput | tuple[tf.Tensor]: - if input_ids is not None and not isinstance(input_ids, tf.Tensor): - input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64) - elif input_ids is not None: - input_ids = tf.cast(input_ids, tf.int64) - - if attention_mask is not None and not isinstance(attention_mask, tf.Tensor): - attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64) - elif attention_mask is not None: - attention_mask = tf.cast(attention_mask, tf.int64) - - if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor): - global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64) - elif global_attention_mask is not None: - global_attention_mask = tf.cast(global_attention_mask, tf.int64) - - if global_attention_mask is None and input_ids is not None: - logger.warning_once("Initializing global attention on CLS token...") - # global attention on cls token - global_attention_mask = tf.zeros_like(input_ids) - updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int64) - indices = tf.pad( - tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0], dtype=tf.int64), axis=1), - paddings=[[0, 0], [0, 1]], - constant_values=0, - ) - global_attention_mask = tf.tensor_scatter_nd_update( - global_attention_mask, - indices, - updates, - ) - - outputs = self.longformer( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFLongformerSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - global_attentions=outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - LONGFORMER_START_DOCSTRING, -) -class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.longformer = TFLongformerMainLayer(config, name="longformer") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), - "global_attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="global_attention_mask"), - } - - @unpack_inputs - @add_start_docstrings_to_model_forward( - LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFLongformerMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFLongformerMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_global_attention_mask = ( - tf.reshape(global_attention_mask, (-1, shape_list(global_attention_mask)[-1])) - if global_attention_mask is not None - else None - ) - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - - outputs = self.longformer( - flat_input_ids, - position_ids=flat_position_ids, - token_type_ids=flat_token_type_ids, - attention_mask=flat_attention_mask, - head_mask=head_mask, - global_attention_mask=flat_global_attention_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFLongformerMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - global_attentions=outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - LONGFORMER_START_DOCSTRING, -) -class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFLongformerTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - global_attention_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.array | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFLongformerTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - - outputs = self.longformer( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFLongformerTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - global_attentions=outputs.global_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "longformer", None) is not None: - with tf.name_scope(self.longformer.name): - self.longformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFLongformerForMaskedLM", - "TFLongformerForMultipleChoice", - "TFLongformerForQuestionAnswering", - "TFLongformerForSequenceClassification", - "TFLongformerForTokenClassification", - "TFLongformerModel", - "TFLongformerPreTrainedModel", - "TFLongformerSelfAttention", -] diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py deleted file mode 100644 index dee4afeadf72..000000000000 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ /dev/null @@ -1,2449 +0,0 @@ -# coding=utf-8 -# Copyright 2022 LongT5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax LongT5 model.""" - -import copy -from typing import Any, Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_longt5 import LongT5Config - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/long-t5-local-base" -_CONFIG_FOR_DOC = "LongT5Config" - -remat = nn_partitioning.remat - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray: - """Pad an array so that a sequence length will be a multiple of `block_len`""" - pad_len = -x.shape[axis] % block_len - pad = [(0, 0)] * x.ndim - pad[axis] = (0, pad_len) - x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value) - return x - - -def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray: - """Split an input array into blocks of a given `block_len` along the given `axis`. If the dimension length - is not a multiple of `block_len`, it will be padded first with selected `pad_value`. - """ - # pad tensor to multiple of block_len - if x.shape[axis] % block_len != 0: - x = _pad_to_multiple(x, block_len, axis, pad_value=0) - num_blocks = x.shape[axis] // block_len - output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1) :] - return x.reshape(output_shape) - - -def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray: - """Concatenate three consecutive blocks for each input block for local attentiont. - For more information, see: https://huggingface.co/papers/2112.07916. - """ - num_blocks = x.shape[block_axis] - - pad = [(0, 0)] * x.ndim - pad[block_axis] = (1, 1) - # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len] - x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value) - - blocks_list: list[np.array] = [] - for i in range(3): - # We use indexing approach here: - # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs - indices = [slice(0, None)] * x.ndim - indices[block_axis] = slice(i, i + num_blocks) - indices = tuple(indices) - blocks_list.append(x[indices]) - return jnp.concatenate(blocks_list, axis=sequence_axis) # [batch_size, num_blocks, 3 * block_len, ...] - - -def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray: - """Makes 3-blocked relative position ids for local attention.""" - position_ids = jnp.arange(3 * block_len, dtype=jnp.int32) - center_position_ids = position_ids[block_len:-block_len] - relative_position_ids = position_ids[None, :] - center_position_ids[:, None] # [block_len, 3 * block_len] - return relative_position_ids - - -def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray: - """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius.""" - relative_position_ids = _make_3block_relative_position_ids(block_len) - locality_mask = jnp.abs(relative_position_ids) < block_len - locality_mask = locality_mask[None, None, :, :] - return jnp.logical_and(local_attention_mask, locality_mask) - - -def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray: - """Prepare attention mask to be applied for a local attention.""" - # [batch_size, num_blocks, block_len] - _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1) - # [batch_size, num_block, 3 * block_len] - _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2) - - _blocked_attention_mask = _blocked_attention_mask[..., None] - _3blocked_attention_mask = _3blocked_attention_mask[..., None, :] - # [batch_size, num_block, block_len, 3 * block_len] - local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask) - local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len) - # [batch_size, 1, num_block, block_len, 3 * block_len] - return local_attention_mask[:, None, ...] - - -def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> tuple[jnp.ndarray, np.ndarray]: - """Obtain the "fixed block" global id corresponding to each input token. - - This implementation is a simplified version of the original Flaxformr implementation adopted from: - https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py. - - In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for - the whole fixed block, are assigned to the preceding block. - - Padding tokens from the original sequence are represented by -1. - """ - batch_size, seq_len = attention_mask.shape[:2] - - def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray: - block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1 - true_block_ends = jnp.logical_and(block_ends, block_ids >= 0) - full_blocks = true_block_ends.sum(-1)[..., None] - block_ids = jnp.minimum(block_ids, full_blocks - 1) - return block_ids - - fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size - fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask - mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0) - global_block_ids = jnp.maximum( - jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype) - ) - # set padding tokens to -1 - global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1) - # [batch_size, seq_len] - global_block_ids = handle_orphan_tokens(global_block_ids) - num_globals = seq_len // global_block_size - - # [batch_size, seq_len // global_block_size] - if num_globals > 0: - _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1) - else: - _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype) - global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1 - global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0) - return global_block_ids, global_segment_ids - - -def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray: - """Create the relative position tensor for local -> global attention.""" - block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size) - global_seq_len = global_segment_ids.shape[-1] - global_positions = jnp.arange(global_seq_len) - side_relative_position = global_positions - block_ids[..., None] - return side_relative_position - - -def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray: - """Compute individual block aggregates by summing over individual blocks.""" - # (batch..., seq_len, global_seq_len)) - one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len) - return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5 -class FlaxLongT5LayerNorm(nn.Module): - hidden_size: int - dtype: jnp.dtype = jnp.float32 - eps: float = 1e-6 - weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones - - def setup(self): - self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) - - def __call__(self, hidden_states): - """ - Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean. - """ - # layer norm should always be calculated in float32 - variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) - hidden_states = hidden_states / jnp.sqrt(variance + self.eps) - - return self.weight * hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5 -class FlaxLongT5DenseActDense(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->LongT5 -class FlaxLongT5DenseGatedActDense(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi_0 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wi_1 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5 -class FlaxLongT5LayerFF(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.is_gated_act: - self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype) - else: - self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype) - - self.layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__(self, hidden_states, deterministic=True): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) - hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5 -class FlaxLongT5Attention(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool = False - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.relative_attention_num_buckets = self.config.relative_attention_num_buckets - self.relative_attention_max_distance = self.config.relative_attention_max_distance - self.d_model = self.config.d_model - self.key_value_proj_dim = self.config.d_kv - self.n_heads = self.config.num_heads - self.dropout = self.config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - - self.q = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std), - dtype=self.dtype, - ) - self.k = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.v = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.o = nn.Dense( - self.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std), - dtype=self.dtype, - ) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0) * num_buckets - relative_position = jnp.abs(relative_position) - else: - relative_position = -jnp.clip(relative_position, a_max=0) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) - ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) - - relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) - - return relative_buckets.astype("i4") - - def compute_bias(self, query_length, key_length): - """Compute binned relative position bias""" - context_position = jnp.arange(query_length, dtype="i4")[:, None] - memory_position = jnp.arange(key_length, dtype="i4")[None, :] - - relative_position = memory_position - context_position - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=(not self.causal), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - - values = self.relative_attention_bias(relative_position_bucket) - values = values.transpose((2, 0, 1))[None, :, :, :] - return values - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) - value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions - # that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def _create_position_bias( - self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ): - cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) - key_length = key_states.shape[1] - query_length = key_length if cache_is_filled else query_states.shape[1] - - if self.has_relative_attention_bias: - position_bias = self.compute_bias(query_length, key_length) - elif attention_mask is not None: - position_bias = jnp.zeros_like(attention_mask) - else: - position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) - - # if key and values are already calculated, only the last query position bias should be taken - if cache_is_filled: - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - position_bias = jax.lax.dynamic_slice( - position_bias, - (0, 0, causal_attention_mask_shift, 0), - (1, self.n_heads, seq_length, max_decoder_length), - ) - return position_bias - - def __call__( - self, - hidden_states, - attention_mask=None, - key_value_states=None, - position_bias=None, - use_cache=False, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - batch_size, seq_length = hidden_states.shape[:2] - - # q, k, v projections - query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) - key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) - value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) - - # reshape to (batch_size, seq_length, n_heads, head_dim) - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # counter-act scaling in dot_product_attention_weights function - query_states *= jnp.sqrt(query_states.shape[-1]) - - # for fast decoding causal attention mask should be shifted - causal_attention_mask_shift = ( - self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 - ) - # create causal attention_mask; attention_mask has to be defined when model is causal - if self.causal: - causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") - - # fast decoding for generate requires special attention_mask - if self.has_variable("cache", "cached_key"): - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_attention_mask = jax.lax.dynamic_slice( - causal_attention_mask, - (0, 0, causal_attention_mask_shift, 0), - (1, 1, seq_length, max_decoder_length), - ) - - # broadcast causal attention mask & attention mask to fit for merge - causal_attention_mask = jnp.broadcast_to( - causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] - ) - attention_mask = jnp.broadcast_to( - jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape - ) - attention_mask = combine_masks(attention_mask, causal_attention_mask) - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # replace masked positions with -10_000 - if attention_mask is not None: - mask_value = jnp.finfo(self.dtype).min - attention_mask = jax.lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, mask_value).astype(self.dtype), - ) - - if position_bias is None: - # compute position bias (only for first layer) - position_bias = self._create_position_bias( - key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ) - - if attention_mask is not None: - position_bias = position_bias + attention_mask - - # create dropout rng - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # Softmax(QK^T) - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=position_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - ) - - # multiply with value states - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - # bring back to (batch_size, seq_length, d_model) - attn_output = self._merge_heads(attn_output) - - # apply output matrix - attn_output = self.o(attn_output) - - outputs = (attn_output, position_bias) - - if output_attentions: - outputs = outputs + (attn_weights,) - - return outputs - - -class FlaxLongT5LocalAttention(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.relative_attention_num_buckets = self.config.relative_attention_num_buckets - self.relative_attention_max_distance = self.config.relative_attention_max_distance - self.d_model = self.config.d_model - self.key_value_proj_dim = self.config.d_kv - self.n_heads = self.config.num_heads - self.local_radius = self.config.local_radius - self.block_len = self.local_radius + 1 - self.dropout = self.config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - - self.q = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std), - dtype=self.dtype, - ) - self.k = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.v = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.o = nn.Dense( - self.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std), - dtype=self.dtype, - ) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - ) - - @staticmethod - # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0) * num_buckets - relative_position = jnp.abs(relative_position) - else: - relative_position = -jnp.clip(relative_position, a_max=0) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) - ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) - - relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) - - return relative_buckets.astype("i4") - - def compute_bias(self, block_length: int): - """Compute binned relative position bias""" - memory_position = jnp.arange(3 * block_length, dtype="i4") - context_position = memory_position[block_length:-block_length] - - relative_position = memory_position[None, :] - context_position[:, None] - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - - values = self.relative_attention_bias(relative_position_bucket) - values = values.transpose((2, 0, 1))[None, None, :, :, :] - return values - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim) - - def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray: - # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) - if self.has_relative_attention_bias: - position_bias = self.compute_bias(block_len) - elif attention_mask is not None: - position_bias = jnp.zeros_like(attention_mask) - else: - position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype) - - return position_bias - - def __call__( - self, - hidden_states, - attention_mask=None, - key_value_states=None, - position_bias=None, - output_attentions=False, - deterministic=True, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - batch_size, seq_length = hidden_states.shape[:2] - - # q, k, v projections - query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) - key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) - value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) - - # reshape to (batch_size, seq_length, n_heads, head_dim) - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim) - query_states = _split_into_blocks(query_states, self.block_len, axis=1) - key_states = _split_into_blocks(key_states, self.block_len, axis=1) - value_states = _split_into_blocks(value_states, self.block_len, axis=1) - - # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) - key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2) - value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2) - - # counter-act scaling in dot_product_attention_weights function - query_states *= jnp.sqrt(query_states.shape[-1]) - - if attention_mask is not None: - attention_mask = _get_local_attention_mask(attention_mask, self.block_len) - - # replace masked positions with -10_000 - attention_mask = jax.lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, -1e10).astype(self.dtype), - ) - - if position_bias is None: - # compute position bias (only for first layer) - position_bias = self._create_position_bias(self.block_len, attention_mask) - - if attention_mask is not None: - position_bias = position_bias + attention_mask.swapaxes(1, 2) - - # create dropout rng - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # Softmax(QK^T) - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=position_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - ) - - # multiply with value states - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - # bring back to (batch_size, seq_length, d_model) - attn_output = self._merge_heads(attn_output) - attn_output = attn_output[:, :seq_length, :] - - # apply output matrix - attn_output = self.o(attn_output) - - outputs = (attn_output, position_bias) - - if output_attentions: - outputs = outputs + (attn_weights,) - - return outputs - - -class FlaxLongT5TransientGlobalAttention(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.relative_attention_num_buckets = self.config.relative_attention_num_buckets - self.relative_attention_max_distance = self.config.relative_attention_max_distance - self.d_model = self.config.d_model - self.key_value_proj_dim = self.config.d_kv - self.n_heads = self.config.num_heads - self.local_radius = self.config.local_radius - self.block_len = self.local_radius + 1 - self.global_block_size = self.config.global_block_size - self.dropout = self.config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - - self.q = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std), - dtype=self.dtype, - ) - self.k = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.v = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.o = nn.Dense( - self.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std), - dtype=self.dtype, - ) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - ) - - # Relativen attention bias & Layer norm for global attention - if self.has_relative_attention_bias: - self.global_relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - ) - self.global_input_layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - - @staticmethod - # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0) * num_buckets - relative_position = jnp.abs(relative_position) - else: - relative_position = -jnp.clip(relative_position, a_max=0) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) - ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) - - relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) - - return relative_buckets.astype("i4") - - def compute_bias(self, block_length: int): - """Compute binned relative position bias""" - memory_position = jnp.arange(3 * block_length, dtype="i4") - context_position = memory_position[block_length:-block_length] - - relative_position = memory_position[None, :] - context_position[:, None] - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - - values = self.relative_attention_bias(relative_position_bucket) - values = values.transpose((2, 0, 1))[None, None, :, :, :] - return values - - def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray: - # (batch_size, 1, 1, seq_len, global_seq_len) - side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...] - attention_side_bias = jax.lax.select( - side_attention_mask > 0, - jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype), - ) - # (batch_size, seq_len, global_seq_len) - side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size) - side_relative_position_bucket = self._relative_position_bucket( - side_relative_position, - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - # (batch_size, seq_len, global_seq_len, num_heads) - side_bias = self.global_relative_attention_bias(side_relative_position_bucket) - - # (batch_size, 1, num_heads, seq_len, global_seq_len) - side_bias = jnp.transpose(side_bias, (0, 3, 1, 2)) - # (batch_size, num_heads, seq_len, global_seq_len) - attention_side_bias = attention_side_bias + side_bias - return attention_side_bias - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim) - - def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray: - # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len) - if self.has_relative_attention_bias: - position_bias = self.compute_bias(block_len) - elif attention_mask is not None: - position_bias = jnp.zeros_like(attention_mask) - else: - position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype) - - return position_bias - - def __call__( - self, - hidden_states, - attention_mask=None, - key_value_states=None, - position_bias=None, - output_attentions=False, - deterministic=True, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - batch_size, seq_length = hidden_states.shape[:2] - - # Prepare components for transient-global attention - # Obtain block_ids and global_segment_ids - # global_seq_len := seq_len // self.global_block_size - # shapes: (batch_size, seq_len) & (batch_size, global_seq_len) - block_ids, global_segment_ids = _make_global_fixed_block_ids( - attention_mask if attention_mask is not None else jnp.ones((batch_size, seq_length)), - self.global_block_size, - ) - # Create global inputs - _global_seq_len = global_segment_ids.shape[-1] - global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len) - global_inputs = self.global_input_layer_norm(global_inputs) - - # q, k, v projections - query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) - key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) - value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) - - # reshape to (batch_size, seq_length, n_heads, head_dim) - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # Get global/side key/value_states - side_key_states = self.k(global_inputs) - side_value_states = self.v(global_inputs) - - # reshape to (batch_size, global_seq_len, n_heads, head_dim) - side_key_states = self._split_heads(side_key_states) - side_value_states = self._split_heads(side_value_states) - - # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim) - query_states = _split_into_blocks(query_states, self.block_len, axis=1) - key_states = _split_into_blocks(key_states, self.block_len, axis=1) - value_states = _split_into_blocks(value_states, self.block_len, axis=1) - - # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head) - key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2) - value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2) - - # Tile side inputs across local key/value blocks - # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head) - reps = [1] * (side_key_states.ndim + 1) - reps[1] = key_states.shape[1] - side_key_states = jnp.tile(side_key_states[:, None, ...], reps) - side_value_states = jnp.tile(side_value_states[:, None, ...], reps) - - # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones - # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head) - key_states = jnp.concatenate((key_states, side_key_states), axis=2) - value_states = jnp.concatenate((value_states, side_value_states), axis=2) - - # counter-act scaling in dot_product_attention_weights function - query_states *= jnp.sqrt(query_states.shape[-1]) - - if attention_mask is not None: - local_attention_mask = _get_local_attention_mask(attention_mask, self.block_len) - local_attention_mask = jax.lax.select( - local_attention_mask > 0, - jnp.full(local_attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(local_attention_mask.shape, -1e10).astype(self.dtype), - ) - else: - local_attention_mask = None - - if position_bias is None: - # compute position bias (only for first layer) - position_bias = self._create_position_bias(self.block_len, attention_mask) - if local_attention_mask is not None: - position_bias = position_bias + local_attention_mask.swapaxes(1, 2) - - # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len) - if attention_mask is None: - attention_mask = jnp.ones((batch_size, seq_length)) - side_position_bias = self.compute_side_bias(attention_mask, global_segment_ids) - side_position_bias = _split_into_blocks(side_position_bias, self.block_len, axis=-2) - side_position_bias = jnp.swapaxes(side_position_bias, 1, 2) - position_bias = jnp.concatenate((position_bias, side_position_bias), axis=-1) - - # create dropout rng - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # Softmax(QK^T) - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=position_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - ) - - # multiply with value states - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - # bring back to (batch_size, seq_length, d_model) - attn_output = self._merge_heads(attn_output) - attn_output = attn_output[:, :seq_length, :] - - # apply output matrix - attn_output = self.o(attn_output) - - outputs = (attn_output, position_bias) - - if output_attentions: - outputs = outputs + (attn_weights,) - - return outputs - - -class FlaxLongT5LayerLocalSelfAttention(nn.Module): - """Local self attention used in encoder""" - - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.LocalSelfAttention = FlaxLongT5LocalAttention( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype - ) - self.layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - **kwargs: Any, # to accept init_cache kwargs - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.LocalSelfAttention( - normed_hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module): - """Transient-Global self attention used in encoder""" - - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype - ) - self.layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - **kwargs: Any, # to accept init_cache kwargs - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.TransientGlobalSelfAttention( - normed_hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5 -class FlaxLongT5LayerSelfAttention(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.SelfAttention = FlaxLongT5Attention( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - causal=self.config.causal, - dtype=self.dtype, - ) - self.layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5 -class FlaxLongT5LayerCrossAttention(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.EncDecAttention = FlaxLongT5Attention( - self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype - ) - self.layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - attention_mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -class FlaxLongT5Block(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.causal = self.config.causal - if self.causal: - attention_layer = FlaxLongT5LayerSelfAttention - elif self.config.encoder_attention_type == "local": - attention_layer = FlaxLongT5LayerLocalSelfAttention - elif self.config.encoder_attention_type == "transient-global": - attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention - else: - raise ValueError( - "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, " - f"but got {self.config.encoder_attention_type}." - ) - self.layer = ( - attention_layer( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - name=str(0), - dtype=self.dtype, - ), - ) - feed_forward_index = 1 - if self.causal: - self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) - feed_forward_index += 1 - - self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) - - # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5 - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - return_dict=True, - deterministic=True, - init_cache=False, - ): - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = self_attention_outputs[0] - attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights - - do_cross_attention = self.causal and encoder_hidden_states is not None - if do_cross_attention: - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = cross_attention_outputs[0] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[1:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - outputs = outputs + attention_outputs - - # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5 -class FlaxLongT5LayerCollection(nn.Module): - config: LongT5Config - has_relative_attention_bias: bool - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxLongT5Block( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype - ) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - return self.layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5 -class FlaxLongT5BlockCollection(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - if self.gradient_checkpointing: - FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8)) - self.blocks = [ - FlaxLongT5CheckpointLayer( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - else: - self.blocks = [ - FlaxLongT5LayerCollection( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - - def __call__( - self, - hidden_states=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - deterministic: bool = True, - init_cache: bool = False, - ): - # Prepare head mask if needed - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.causal) else None - position_bias = None - encoder_decoder_position_bias = None - - for i, layer_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask, - position_bias, - encoder_hidden_states, - encoder_attention_mask, - encoder_decoder_position_bias, - output_attentions, - deterministic, - init_cache, - ) - - hidden_states = layer_outputs[0] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[1] - - if self.causal and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) - if self.causal: - all_cross_attentions = all_cross_attentions + (layer_outputs[4],) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5 -class FlaxLongT5Stack(nn.Module): - config: LongT5Config - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - - self.block = FlaxLongT5BlockCollection( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.final_layer_norm = FlaxLongT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - init_cache: bool = False, - ): - hidden_states = self.embed_tokens(input_ids) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - deterministic=deterministic, - init_cache=init_cache, - ) - - hidden_states = outputs[0] - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - # Add last layer - all_hidden_states = None - - if output_hidden_states: - all_hidden_states = outputs.hidden_states - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - if output_hidden_states: - return ( - hidden_states, - all_hidden_states, - ) + outputs[2:] - return (hidden_states,) + outputs[1:] - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -LONGT5_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so - you should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 - Training](./longt5#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -LONGT5_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For training, `decoder_input_ids` should be provided. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -LONGT5_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so - you should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5 - Training](./longt5#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5 - Training](./longt5#training). - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LongT5Config - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: LongT5Config, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = jnp.ones_like(input_ids) - decoder_attention_mask = jnp.ones_like(input_ids) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: jnp.ndarray = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if decoder_input_ids is None: - raise ValueError( - "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" - " here." - ) - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # prepare decoder inputs - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxLongT5Attention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - -LONGT5_START_DOCSTRING = r""" - The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long - Sequences](https://huggingface.co/papers/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo - Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising - generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different - efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention. - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`LongT5Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -@add_start_docstrings( - "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.", - LONGT5_START_DOCSTRING, -) -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5 -class FlaxLongT5Module(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), - dtype=self.dtype, - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - self.encoder = FlaxLongT5Stack( - encoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxLongT5Stack( - decoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode if needed (training, first prediction pass) - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5 -class FlaxLongT5Model(FlaxLongT5PreTrainedModel): - module_class = FlaxLongT5Module - - -append_call_sample_docstring(FlaxLongT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - -FLAX_LONGT5_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxLongT5Model - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="np" - ... ).input_ids - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - - -overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - - -@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5 -class FlaxLongT5ForConditionalGenerationModule(nn.Module): - config: LongT5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.model_dim = self.config.d_model - - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), - dtype=self.dtype, - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = FlaxLongT5Stack( - encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxLongT5Stack( - decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), - dtype=self.dtype, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = self.shared.variables["params"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = self.lm_head(sequence_output) - - if not return_dict: - return (lm_logits,) + decoder_outputs[1:] + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel): - module_class = FlaxLongT5ForConditionalGenerationModule - - @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base") - - >>> text = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxLongT5Attention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - decoder_outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.config.d_model**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = module.shared.variables["params"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = module.lm_head(sequence_output) - - return lm_logits, decoder_outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - extended_attention_mask = jax.lax.dynamic_update_slice( - extended_attention_mask, decoder_attention_mask, (0, 0) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - return model_kwargs - - -FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base") - - >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"]).sequences - >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` -""" - - -overwrite_call_docstring( - FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = ["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"] diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 1dd77bc36f80..000000000000 --- a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert LXMERT checkpoint.""" - -import argparse - -import torch - -from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - config = LxmertConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = LxmertForPreTraining(config) - - # Load weights from tf checkpoint - load_tf_weights_in_lxmert(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/lxmert/modeling_tf_lxmert.py b/src/transformers/models/lxmert/modeling_tf_lxmert.py deleted file mode 100644 index aee9fb785796..000000000000 --- a/src/transformers/models/lxmert/modeling_tf_lxmert.py +++ /dev/null @@ -1,1660 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team, and the -# Lxmert Authors. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 LXMERT model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - shape_list, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_lxmert import LxmertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "unc-nlp/lxmert-base-uncased" -_CONFIG_FOR_DOC = "LxmertConfig" - - -@dataclass -class TFLxmertModelOutput(ModelOutput): - """ - Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language, - visual, and, cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship" - encoder") - - - Args: - language_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the language encoder. - vision_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the visual encoder. - pooled_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed - by a Linear layer and a Tanh activation function. The Linear - language_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for input features + one for the output of each cross-modality layer) of shape - `(batch_size, sequence_length, hidden_size)`. - vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for input features + one for the output of each cross-modality layer) of shape - `(batch_size, sequence_length, hidden_size)`. - language_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - cross_encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - language_output: tf.Tensor | None = None - vision_output: tf.Tensor | None = None - pooled_output: tf.Tensor | None = None - language_hidden_states: tuple[tf.Tensor] | None = None - vision_hidden_states: tuple[tf.Tensor] | None = None - language_attentions: tuple[tf.Tensor] | None = None - vision_attentions: tuple[tf.Tensor] | None = None - cross_encoder_attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFLxmertForPreTrainingOutput(ModelOutput): - """ - Output type of [`LxmertForPreTraining`]. - - Args: - loss (*optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`): - Total loss as the sum of the masked language modeling loss and the next sequence prediction - (classification) loss. - prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cross_relationship_score (`tf.Tensor` of shape `(batch_size, 2)`): - Prediction scores of the textual matching objective (classification) head (scores of True/False - continuation before SoftMax). - question_answering_score (`tf.Tensor` of shape `(batch_size, n_qa_answers)`): - Prediction scores of question answering objective (classification). - language_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for input features + one for the output of each cross-modality layer) of shape - `(batch_size, sequence_length, hidden_size)`. - vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for input features + one for the output of each cross-modality layer) of shape - `(batch_size, sequence_length, hidden_size)`. - language_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - cross_encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - - """ - - loss: tf.Tensor | None = None - prediction_logits: tf.Tensor | None = None - cross_relationship_score: tf.Tensor | None = None - question_answering_score: tf.Tensor | None = None - language_hidden_states: tuple[tf.Tensor] | None = None - vision_hidden_states: tuple[tf.Tensor] | None = None - language_attentions: tuple[tf.Tensor] | None = None - vision_attentions: tuple[tf.Tensor] | None = None - cross_encoder_attentions: tuple[tf.Tensor] | None = None - - -class TFLxmertVisualFeatureEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - # Object feature encoding - self.visn_fc = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="visn_fc", - ) - self.visn_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="visn_layer_norm") - - # Box position encoding - self.box_fc = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="box_fc", - ) - self.box_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="box_layer_norm") - - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.feat_dim = config.visual_feat_dim - self.pos_dim = config.visual_pos_dim - self.config = config - - def call(self, visn_input, training=False): - feats, boxes = visn_input - - x = self.visn_fc(feats) - x = self.visn_layer_norm(x) - y = self.box_fc(boxes) - y = self.box_layer_norm(y) - output = (x + y) / 2 - - output = self.dropout(output, training=training) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "visn_fc", None) is not None: - with tf.name_scope(self.visn_fc.name): - self.visn_fc.build([None, None, self.feat_dim]) - if getattr(self, "visn_layer_norm", None) is not None: - with tf.name_scope(self.visn_layer_norm.name): - self.visn_layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "box_fc", None) is not None: - with tf.name_scope(self.box_fc.name): - self.box_fc.build([None, None, self.pos_dim]) - if getattr(self, "box_layer_norm", None) is not None: - with tf.name_scope(self.box_layer_norm.name): - self.box_layer_norm.build([None, None, self.config.hidden_size]) - - -class TFLxmertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def call(self, input_ids=None, token_type_ids=None, inputs_embeds=None, training=False): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFLxmertAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - - self.num_attention_heads = config.num_attention_heads - assert config.hidden_size % config.num_attention_heads == 0 - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="query", - ) - self.key = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="key", - ) - self.value = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - name="value", - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.ctx_dim = config.hidden_size - self.config = config - - def transpose_for_scores(self, x, batch_size): - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call(self, hidden_states, context, attention_mask, output_attentions, training=False): - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(context) - mixed_value_layer = self.value(context) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul( - query_layer, key_layer, transpose_b=True - ) # (batch size, num_heads, seq_len_q, seq_len_k) - dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype) # scale attention_scores - attention_scores = attention_scores / tf.math.sqrt(dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFLxmertModel call() function) - attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - context_layer = tf.matmul(attention_probs, value_layer) - - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - context_layer = tf.reshape( - context_layer, (batch_size, -1, self.all_head_size) - ) # (batch_size, seq_len_q, all_head_size) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.ctx_dim]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.ctx_dim]) - - -class TFLxmertIntermediate(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.intermediate_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFLxmertOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFLxmertAttentionOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, input_tensor, training=False): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFLxmertSelfAttentionLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.self = TFLxmertAttention(config, name="self") - self.attention_output = TFLxmertAttentionOutput(config, name="output") - - def call(self, input_tensor, attention_mask, output_attentions, training=False): - # Self attention attends to itself, thus keys and queries are the same (input_tensor). - self_output = self.self(input_tensor, input_tensor, attention_mask, output_attentions) - if output_attentions: - attention_probs = self_output[1] - attention_output = self.attention_output(self_output[0], input_tensor) - return (attention_output, attention_probs) if output_attentions else (attention_output,) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "attention_output", None) is not None: - with tf.name_scope(self.attention_output.name): - self.attention_output.build(None) - - -class TFLxmertCrossAttentionLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.att = TFLxmertAttention(config, name="att") - self.attention_output = TFLxmertAttentionOutput(config, name="output") - - def call( - self, - input_tensor, - ctx_tensor, - ctx_att_mask, - output_attentions=False, - training=False, - ): - output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions, training=training) - if output_attentions: - attention_probs = output[1] - attention_output = self.attention_output(output[0], input_tensor, training=training) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "att", None) is not None: - with tf.name_scope(self.att.name): - self.att.build(None) - if getattr(self, "attention_output", None) is not None: - with tf.name_scope(self.attention_output.name): - self.attention_output.build(None) - - -class TFLxmertLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.attention = TFLxmertSelfAttentionLayer(config, name="attention") - self.intermediate = TFLxmertIntermediate(config, name="intermediate") - self.transformer_output = TFLxmertOutput(config, name="output") - - def call(self, hidden_states, attention_mask, output_attentions, training=False): - attention_outputs = self.attention(hidden_states, attention_mask, output_attentions, training=training) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(attention_output) - layer_output = self.transformer_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "transformer_output", None) is not None: - with tf.name_scope(self.transformer_output.name): - self.transformer_output.build(None) - - -class TFLxmertXLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.visual_attention = TFLxmertCrossAttentionLayer(config, name="visual_attention") - - # Self-attention Layers - self.lang_self_att = TFLxmertSelfAttentionLayer(config, name="lang_self_att") - self.visn_self_att = TFLxmertSelfAttentionLayer(config, name="visn_self_att") - - # Intermediate and Output Layers (FFNs) - self.lang_inter = TFLxmertIntermediate(config, name="lang_inter") - self.lang_output = TFLxmertOutput(config, name="lang_output") - self.visn_inter = TFLxmertIntermediate(config, name="visn_inter") - self.visn_output = TFLxmertOutput(config, name="visn_output") - - def cross_att( - self, - lang_input, - lang_attention_mask, - visn_input, - visn_attention_mask, - output_attentions, - training=False, - ): - # Cross Attention - - # Keras saving and loading model *does not work* with the same inputs for two layers. - lang_attention_lang_input = tf.identity(lang_input) - visn_attention_lang_input = tf.identity(lang_input) - lang_attention_visn_input = tf.identity(visn_input) - visn_attention_visn_input = tf.identity(visn_input) - - lang_att_output = self.visual_attention( - lang_attention_lang_input, - lang_attention_visn_input, - visn_attention_mask, - output_attentions=output_attentions, - training=training, - ) - visn_att_output = self.visual_attention( - visn_attention_visn_input, - visn_attention_lang_input, - lang_attention_mask, - output_attentions=output_attentions, - training=training, - ) - return lang_att_output, visn_att_output - - def self_att( - self, - lang_input, - lang_attention_mask, - visn_input, - visn_attention_mask, - training=False, - ): - # Self Attention - output_attentions = False - lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions, training=training) - visn_att_output = self.visn_self_att(visn_input, visn_attention_mask, output_attentions, training=training) - return lang_att_output[0], visn_att_output[0] - - def output_fc(self, lang_input, visn_input, training=False): - # FC layers - lang_inter_output = self.lang_inter(lang_input) - visn_inter_output = self.visn_inter(visn_input) - - # Layer output - lang_output = self.lang_output(lang_inter_output, lang_input, training) - visn_output = self.visn_output(visn_inter_output, visn_input, training) - return lang_output, visn_output - - def call( - self, - lang_feats, - lang_attention_mask, - visn_feats, - visn_attention_mask, - output_attentions, - training=False, - ): - lang_att_output = lang_feats - visn_att_output = visn_feats - - lang_att_output, visn_att_output = self.cross_att( - lang_att_output, - lang_attention_mask, - visn_att_output, - visn_attention_mask, - output_attentions, - training=training, - ) - attention_probs = lang_att_output[1:] - lang_att_output, visn_att_output = self.self_att( - lang_att_output[0], - lang_attention_mask, - visn_att_output[0], - visn_attention_mask, - training=training, - ) - lang_output, visn_output = self.output_fc(lang_att_output, visn_att_output, training=training) - - return (lang_output, visn_output, attention_probs[0]) if output_attentions else (lang_output, visn_output) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "visual_attention", None) is not None: - with tf.name_scope(self.visual_attention.name): - self.visual_attention.build(None) - if getattr(self, "lang_self_att", None) is not None: - with tf.name_scope(self.lang_self_att.name): - self.lang_self_att.build(None) - if getattr(self, "visn_self_att", None) is not None: - with tf.name_scope(self.visn_self_att.name): - self.visn_self_att.build(None) - if getattr(self, "lang_inter", None) is not None: - with tf.name_scope(self.lang_inter.name): - self.lang_inter.build(None) - if getattr(self, "lang_output", None) is not None: - with tf.name_scope(self.lang_output.name): - self.lang_output.build(None) - if getattr(self, "visn_inter", None) is not None: - with tf.name_scope(self.visn_inter.name): - self.visn_inter.build(None) - if getattr(self, "visn_output", None) is not None: - with tf.name_scope(self.visn_output.name): - self.visn_output.build(None) - - -class TFLxmertEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.visn_fc = TFLxmertVisualFeatureEncoder(config, name="visn_fc") - - # Number of layers - self.num_l_layers = config.l_layers - self.num_x_layers = config.x_layers - self.num_r_layers = config.r_layers - - # Layers - # Using self.layer instead of self.l_layer to support loading BERT weights. - self.layer = [TFLxmertLayer(config, name=f"layer_._{i}") for i in range(self.num_l_layers)] - self.x_layers = [TFLxmertXLayer(config, name=f"x_layers_._{i}") for i in range(self.num_x_layers)] - self.r_layers = [TFLxmertLayer(config, name=f"r_layers_._{i}") for i in range(self.num_r_layers)] - self.config = config - - def call( - self, - lang_feats=None, - lang_attention_mask=None, - visual_feats=None, - visual_pos=None, - visual_attention_mask=None, - output_attentions=None, - training=False, - ): - vision_hidden_states = () - language_hidden_states = () - vision_attentions = () if output_attentions or self.config.output_attentions else None - language_attentions = () if output_attentions or self.config.output_attentions else None - cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None - - visual_feats = self.visn_fc([visual_feats, visual_pos], training=training) - - # Run language layers - for layer_module in self.layer: - l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions, training=training) - lang_feats = l_outputs[0] - language_hidden_states = language_hidden_states + (lang_feats,) - if language_attentions is not None: - language_attentions = language_attentions + (l_outputs[1],) - - # Run relational layers - for layer_module in self.r_layers: - v_outputs = layer_module( - visual_feats, - visual_attention_mask, - output_attentions, - training=training, - ) - visual_feats = v_outputs[0] - vision_hidden_states = vision_hidden_states + (visual_feats,) - if vision_attentions is not None: - vision_attentions = vision_attentions + (v_outputs[1],) - - # Run cross-modality layers - for layer_module in self.x_layers: - x_outputs = layer_module( - lang_feats, - lang_attention_mask, - visual_feats, - visual_attention_mask, - output_attentions, - training=training, - ) - lang_feats, visual_feats = x_outputs[:2] - vision_hidden_states = vision_hidden_states + (visual_feats,) - language_hidden_states = language_hidden_states + (lang_feats,) - if cross_encoder_attentions is not None: - cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],) - - visual_encoder_outputs = ( - vision_hidden_states, - vision_attentions if output_attentions else None, - ) - lang_encoder_outputs = ( - language_hidden_states, - language_attentions if output_attentions else None, - ) - - return ( - visual_encoder_outputs, - lang_encoder_outputs, - cross_encoder_attentions if output_attentions else None, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "visn_fc", None) is not None: - with tf.name_scope(self.visn_fc.name): - self.visn_fc.build(None) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - if getattr(self, "x_layers", None) is not None: - for layer in self.x_layers: - with tf.name_scope(layer.name): - layer.build(None) - if getattr(self, "r_layers", None) is not None: - for layer in self.r_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFLxmertMainLayer(keras.layers.Layer): - config_class = LxmertConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.num_l_layers = config.l_layers - self.num_x_layers = config.x_layers - self.num_r_layers = config.r_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.embeddings = TFLxmertEmbeddings(config, name="embeddings") - self.encoder = TFLxmertEncoder(config, name="encoder") - self.pooler = TFLxmertPooler(config, name="pooler") - self.config = config - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - visual_feats=None, - visual_pos=None, - attention_mask=None, - visual_attention_mask=None, - token_type_ids=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - if visual_pos is None or visual_feats is None: - raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - # Positional Word Embeddings - embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds, training) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - if visual_attention_mask is not None: - extended_visual_attention_mask = tf.reshape(visual_attention_mask, (input_shape[0], 1, 1, input_shape[1])) - extended_visual_attention_mask = tf.expand_dims(tf.expand_dims(visual_attention_mask, axis=1), axis=1) - - extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, dtype=embedding_output.dtype) - extended_visual_attention_mask = tf.multiply( - tf.subtract(one_cst, extended_visual_attention_mask), ten_thousand_cst - ) - else: - extended_visual_attention_mask = None - - # Run Lxmert encoder - encoder_outputs = self.encoder( - embedding_output, - extended_attention_mask, - visual_feats, - visual_pos, - extended_visual_attention_mask, - output_attentions, - training, - ) - visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2] - vision_hidden_states = visual_encoder_outputs[0] - language_hidden_states = lang_encoder_outputs[0] - - all_attentions = () - if output_attentions: - language_attentions = lang_encoder_outputs[1] - vision_attentions = visual_encoder_outputs[1] - cross_encoder_attentions = encoder_outputs[2] - all_attentions = ( - language_attentions, - vision_attentions, - cross_encoder_attentions, - ) - - hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else () - - visual_output = vision_hidden_states[-1] - lang_output = language_hidden_states[-1] - pooled_output = self.pooler(lang_output) - - if not return_dict: - return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions - - return TFLxmertModelOutput( - pooled_output=pooled_output, - language_output=lang_output, - vision_output=visual_output, - language_hidden_states=language_hidden_states if output_hidden_states else None, - vision_hidden_states=vision_hidden_states if output_hidden_states else None, - language_attentions=language_attentions if output_attentions else None, - vision_attentions=vision_attentions if output_attentions else None, - cross_encoder_attentions=cross_encoder_attentions if output_attentions else None, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFLxmertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LxmertConfig - base_model_prefix = "lxmert" - - @property - def dummy_inputs(self): - """ - Dummy inputs to build the network. - - Returns: - tf.Tensor with dummy inputs - """ - batch_size = 2 - num_visual_features = 10 - input_ids = tf.constant([[3, 5, 6], [2, 3, 4]], dtype=tf.int32) - visual_feats = tf.random.uniform((batch_size, num_visual_features, self.config.visual_feat_dim)) - visual_pos = tf.random.uniform((batch_size, num_visual_features, 4)) - - return { - "input_ids": input_ids, - "visual_feats": visual_feats, - "visual_pos": visual_pos, - } - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), - "visual_feats": tf.TensorSpec((None, None, self.config.visual_feat_dim), tf.float32, name="visual_feats"), - "visual_pos": tf.TensorSpec((None, None, 4), tf.float32, name="visual_pos"), - "visual_attention_mask": tf.TensorSpec((None, None), tf.int32, name="visual_attention_mask"), - "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), - } - - -LXMERT_START_DOCSTRING = r""" - - The LXMERT model was proposed in [LXMERT: Learning Cross-Modality Encoder Representations from - Transformers](https://huggingface.co/papers/1908.07490) by Hao Tan and Mohit Bansal. It's a vision and language transformer - model, pre-trained on a variety of multi-modal datasets comprising of GQA, VQAv2.0, MCSCOCO captions, and Visual - genome, using a combination of masked language modeling, region of interest feature regression, cross entropy loss - for question answering attribute prediction, and object tag prediction. - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`LxmertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -LXMERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - visual_feats (`tf.Tensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`): - This input represents visual features. They ROI pooled object features from bounding boxes using a - faster-RCNN model) - - These are currently not provided by the transformers library. - visual_pos (`tf.Tensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`): - This input represents spatial features corresponding to their relative (via index) visual features. The - pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to - 1. - - These are currently not provided by the transformers library. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - visual_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - MMask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.", - LXMERT_START_DOCSTRING, -) -class TFLxmertModel(TFLxmertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.lxmert = TFLxmertMainLayer(config, name="lxmert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFLxmertModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - visual_feats: tf.Tensor | None = None, - visual_pos: tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - visual_attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFLxmertModelOutput: - outputs = self.lxmert( - input_ids, - visual_feats, - visual_pos, - attention_mask, - visual_attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, - training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lxmert", None) is not None: - with tf.name_scope(self.lxmert.name): - self.lxmert.build(None) - - -class TFLxmertPooler(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Lxmert -class TFLxmertPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: LxmertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Lxmert -class TFLxmertLMPredictionHead(keras.layers.Layer): - def __init__(self, config: LxmertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - - self.transform = TFLxmertPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->Lxmert -class TFLxmertMLMHead(keras.layers.Layer): - def __init__(self, config: LxmertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFLxmertLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -class TFLxmertPreTrainingHeads(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - self.predictions = TFLxmertLMPredictionHead(config, input_embeddings, name="predictions") - - self.seq_relationship = keras.layers.Dense( - 2, - kernel_initializer=get_initializer(config.initializer_range), - name="seq_relationship", - ) - self.config = config - - def call(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - if getattr(self, "seq_relationship", None) is not None: - with tf.name_scope(self.seq_relationship.name): - self.seq_relationship.build([None, None, self.config.hidden_size]) - - -class TFLxmertVisualAnswerHead(keras.layers.Layer): - def __init__(self, config, num_labels, **kwargs): - super().__init__(**kwargs) - hid_dim = config.hidden_size - self.dense = keras.layers.Dense( - hid_dim * 2, - kernel_initializer=get_initializer(config.initializer_range), - name="logit_fc_._0", - ) - self.activation = get_tf_activation("gelu") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="logit_fc_._2") - self.dense_1 = keras.layers.Dense( - num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="logit_fc_._3", - ) - self.hid_dim = hid_dim - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dense_1(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.hid_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, self.hid_dim * 2]) - if getattr(self, "dense_1", None) is not None: - with tf.name_scope(self.dense_1.name): - self.dense_1.build([None, None, self.hid_dim * 2]) - - -class TFLxmertVisualObjHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.transform = TFLxmertPredictionHeadTransform(config, name="transform") - - # Decide the use of visual losses - visual_losses = {} - if config.visual_obj_loss: - visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels} - if config.visual_attr_loss: - visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels} - if config.visual_feat_loss: - visual_losses["feat"] = {"shape": (-1, 2048), "num": config.visual_feat_dim} - self.visual_losses = visual_losses - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder_dict = { - key: keras.layers.Dense( - self.visual_losses[key]["num"], - kernel_initializer=get_initializer(config.initializer_range), - name=f"decoder_dict.{key}", - ) - for key in self.visual_losses - } - self.config = config - - def call(self, hidden_states): - hidden_states = self.transform(hidden_states) - output = {} - for key in self.visual_losses: - output[key] = self.decoder_dict[key](hidden_states) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - if getattr(self, "decoder_dict", None) is not None: - for layer in self.decoder_dict.values(): - with tf.name_scope(layer.name): - layer.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings("""Lxmert Model with a `language modeling` head on top.""", LXMERT_START_DOCSTRING) -class TFLxmertForPreTraining(TFLxmertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.config = config - self.num_qa_labels = config.num_qa_labels - self.visual_loss_normalizer = config.visual_loss_normalizer - - # Use of pretraining tasks - self.task_mask_lm = config.task_mask_lm - self.task_obj_predict = config.task_obj_predict - self.task_matched = config.task_matched - self.task_qa = config.task_qa - - # Lxmert backbone - self.lxmert = TFLxmertMainLayer(config, name="lxmert") - - # Pre-training heads - self.cls = TFLxmertPreTrainingHeads(config, self.lxmert.embeddings, name="cls") - if self.task_obj_predict: - self.obj_predict_head = TFLxmertVisualObjHead(config, name="obj_predict_head") - if self.task_qa: - self.answer_head = TFLxmertVisualAnswerHead(config, self.num_qa_labels, name="answer_head") - - # Loss functions - self.loss_fcts = { - "l2": keras.losses.Huber(delta=1.0, name="huber_loss"), - "visn_ce": keras.losses.SparseCategoricalCrossentropy(from_logits=True), - "ce": keras.losses.SparseCategoricalCrossentropy(from_logits=True), - } - - visual_losses = {} - if config.visual_obj_loss: - visual_losses["obj"] = { - "shape": (-1,), - "num": config.num_object_labels, - "loss": "visn_ce", - } - if config.visual_attr_loss: - visual_losses["attr"] = { - "shape": (-1,), - "num": config.num_attr_labels, - "loss": "visn_ce", - } - if config.visual_feat_loss: - visual_losses["feat"] = { - "shape": (-1, config.visual_feat_dim), - "num": config.visual_feat_dim, - "loss": "l2", - } - self.visual_losses = visual_losses - - @property - def dummy_inputs(self): - """ - Dummy inputs to build the network. - - Returns: - tf.Tensor with dummy inputs - """ - batch_size = 2 - num_visual_features = 10 - input_ids = tf.constant([[3, 5, 6], [2, 3, 4]], dtype=tf.int32) - visual_feats = tf.random.uniform((batch_size, num_visual_features, self.config.visual_feat_dim)) - visual_pos = tf.random.uniform((batch_size, num_visual_features, 4)) - - if self.config.task_obj_predict: - obj_labels = {} - if self.config.visual_attr_loss and self.config.task_obj_predict: - obj_labels["attr"] = ( - tf.ones([batch_size, num_visual_features]), - tf.ones([batch_size, num_visual_features]), - ) - if self.config.visual_feat_loss and self.config.task_obj_predict: - obj_labels["feat"] = ( - tf.ones([batch_size, num_visual_features, self.config.visual_feat_dim]), - tf.ones([batch_size, num_visual_features]), - ) - if self.config.visual_obj_loss and self.config.task_obj_predict: - obj_labels["obj"] = ( - tf.ones([batch_size, num_visual_features]), - tf.ones([batch_size, num_visual_features]), - ) - - return { - **{ - "input_ids": input_ids, - "visual_feats": visual_feats, - "visual_pos": visual_pos, - }, - **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}), - } - - def get_lm_head(self): - return self.cls.predictions - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - visual_feats: tf.Tensor | None = None, - visual_pos: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - visual_attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - masked_lm_labels: tf.Tensor | None = None, - obj_labels: dict[str, tuple[tf.Tensor, tf.Tensor]] | None = None, - matched_label: tf.Tensor | None = None, - ans: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFLxmertForPreTrainingOutput: - r""" - masked_lm_labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - obj_labels (`dict[Str: tuple[tf.Tensor, tf.Tensor]]`, *optional*, defaults to `None`): - each key is named after each one of the visual losses and each element of the tuple is of the shape - `(batch_size, num_features)` and `(batch_size, num_features, visual_feature_dim)` for each the label id and - the label score respectively - matched_label (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the whether or not the text input matches the image (classification) loss. Input - should be a sequence pair (see `input_ids` docstring) Indices should be in `[0, 1]`: - - - 0 indicates that the sentence does not match the image, - - 1 indicates that the sentence does match the image. - ans (`tf.Tensor` of shape `(batch_size)`, *optional*, defaults to `None`): - a one hot representation hof the correct answer *optional* - - Returns: - """ - - lxmert_output = self.lxmert( - input_ids, - visual_feats, - visual_pos, - attention_mask, - visual_attention_mask, - token_type_ids, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict, - training, - ) - - lang_output, visual_output, pooled_output = ( - lxmert_output[0], - lxmert_output[1], - lxmert_output[2], - ) - lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output) - if self.task_qa: - answer_score = self.answer_head(pooled_output) - else: - answer_score = pooled_output[0][0] - - total_loss = ( - None - if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None) - else tf.constant(0.0) - ) - losses = () - if masked_lm_labels is not None and self.task_mask_lm: - masked_lm_loss = self.loss_fcts["ce"]( - tf.reshape(masked_lm_labels, [-1]), - tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]), - ) - total_loss += masked_lm_loss - losses += (masked_lm_loss,) - if matched_label is not None and self.task_matched: - matched_loss = self.loss_fcts["ce"]( - tf.reshape(matched_label, [-1]), - tf.reshape(cross_relationship_score, [-1, 2]), - ) - total_loss += matched_loss - losses += (matched_loss,) - if obj_labels is not None and self.task_obj_predict: - total_visn_loss = 0.0 - visn_prediction_scores_dict = self.obj_predict_head(visual_output) - for key, key_info in self.visual_losses.items(): - label, mask_conf = obj_labels[key] - output_dim = key_info["num"] - loss_fct_name = key_info["loss"] - label_shape = key_info["shape"] - weight = self.visual_loss_normalizer - visn_loss_fct = self.loss_fcts[loss_fct_name] - visn_prediction_scores = visn_prediction_scores_dict[key] - visn_loss = visn_loss_fct( - tf.reshape(label, label_shape), - tf.reshape(visn_prediction_scores, [-1, output_dim]), - ) - - if visn_loss.ndim > 1: # Regression Losses - visn_loss = tf.reduce_mean(visn_loss) - visn_loss = tf.reduce_mean(visn_loss * tf.cast(tf.reshape(mask_conf, [-1]), visn_loss.dtype)) * weight - total_visn_loss += visn_loss - losses += (visn_loss,) - total_loss += total_visn_loss - if ans is not None and self.task_qa: - answer_loss = self.loss_fcts["ce"]( - tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels]) - ) - # exclude "*2" here to match the effect of QA losses. - # Previous: (loss *0) for 6 epochs, (loss *2) for 6 epochs. (Used 10 instead of 6 in EMNLP paper) - # Now : (loss *1) for 12 epochs - # - # * 2 # Multiply by 2 because > half of the data will not have label - total_loss += answer_loss - losses += (answer_loss,) - # return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach() - - if not return_dict: - output = ( - lang_prediction_scores, - cross_relationship_score, - answer_score, - ) + lxmert_output[3:] - return ((total_loss,) + output) if total_loss is not None else output - - return TFLxmertForPreTrainingOutput( - loss=total_loss, - prediction_logits=lang_prediction_scores, - cross_relationship_score=cross_relationship_score, - question_answering_score=answer_score, - language_hidden_states=lxmert_output.language_hidden_states, - vision_hidden_states=lxmert_output.vision_hidden_states, - language_attentions=lxmert_output.language_attentions, - vision_attentions=lxmert_output.vision_attentions, - cross_encoder_attentions=lxmert_output.cross_encoder_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lxmert", None) is not None: - with tf.name_scope(self.lxmert.name): - self.lxmert.build(None) - if getattr(self, "cls", None) is not None: - with tf.name_scope(self.cls.name): - self.cls.build(None) - if getattr(self, "obj_predict_head", None) is not None: - with tf.name_scope(self.obj_predict_head.name): - self.obj_predict_head.build(None) - if getattr(self, "answer_head", None) is not None: - with tf.name_scope(self.answer_head.name): - self.answer_head.build(None) - - -__all__ = [ - "TFLxmertForPreTraining", - "TFLxmertMainLayer", - "TFLxmertModel", - "TFLxmertPreTrainedModel", - "TFLxmertVisualFeatureEncoder", -] diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py deleted file mode 100644 index e1d9bea4fcdb..000000000000 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ /dev/null @@ -1,1500 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Marian Team Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Marian model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_marian import MarianConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de" -_CONFIG_FOR_DOC = "MarianConfig" - - -MARIAN_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`MarianConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -MARIAN_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -MARIAN_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -MARIAN_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def create_sinusoidal_positions(n_pos, dim): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - sentinel = dim // 2 + dim % 2 - out = np.zeros_like(position_enc) - out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) - out[:, sentinel:] = np.cos(position_enc[:, 1::2]) - - return jnp.array(out) - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Marian -class FlaxMarianAttention(nn.Module): - config: MarianConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer with Bart->Marian -class FlaxMarianEncoderLayer(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxMarianAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Marian -class FlaxMarianEncoderLayerCollection(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxMarianEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer with Bart->Marian -class FlaxMarianDecoderLayer(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxMarianAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxMarianAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Marian -class FlaxMarianDecoderLayerCollection(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxMarianDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxMarianEncoder(nn.Module): - config: MarianConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) - self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - positions = jnp.take(self.embed_positions, position_ids, axis=0) - # explicitly cast the positions here, since self.embed_positions are not registered as parameters - positions = positions.astype(inputs_embeds.dtype) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxMarianDecoder(nn.Module): - config: MarianConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) - self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = jnp.take(self.embed_positions, position_ids, axis=0) - # explicitly cast the positions here, since self.embed_positions are not registered as parameters - positions = positions.astype(inputs_embeds.dtype) - - hidden_states = inputs_embeds + positions - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxMarianModule(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxMarianDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxMarianPreTrainedModel(FlaxPreTrainedModel): - config_class = MarianConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: MarianConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxMarianForSequenceClassificationModule - input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module(decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(MARIAN_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MarianConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMarianMTModel - - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=64, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MarianConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxMarianMTModel - - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=64, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxMarianAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare Marian Model transformer outputting raw hidden-states without any specific head on top.", - MARIAN_START_DOCSTRING, -) -class FlaxMarianModel(FlaxMarianPreTrainedModel): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxMarianModule - - -append_call_sample_docstring(FlaxMarianModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -class FlaxMarianMTModule(nn.Module): - config: MarianConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxMarianModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += self.final_logits_bias.astype(self.dtype) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The MARIAN Model with a language modeling head. Can be used for translation.", MARIAN_START_DOCSTRING -) -class FlaxMarianMTModel(FlaxMarianPreTrainedModel): - module_class = FlaxMarianMTModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MarianConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxMarianMTModel - - >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=64, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxMarianAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias.astype(self.dtype) - - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def _adapt_logits_for_beam_search(self, logits): - """This function enforces the padding token never to be generated.""" - logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf")) - return logits - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_MARIAN_MT_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMarianMTModel - - >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> text = "My friends are cool but they eat too many carbs." - >>> input_ids = tokenizer(text, max_length=64, return_tensors="jax").input_ids - - >>> sequences = model.generate(input_ids, max_length=64, num_beams=2).sequences - - >>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True) - >>> # should give *Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.* - ``` -""" - -overwrite_call_docstring( - FlaxMarianMTModel, - MARIAN_INPUTS_DOCSTRING + FLAX_MARIAN_MT_DOCSTRING, -) -append_replace_return_docstrings(FlaxMarianMTModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - - -__all__ = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"] diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py deleted file mode 100644 index c989cfa15f5a..000000000000 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ /dev/null @@ -1,1558 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Marian model.""" - -from __future__ import annotations - -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_marian import MarianConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de" -_CONFIG_FOR_DOC = "MarianConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFMarianSinusoidalPositionalEmbedding(keras.layers.Layer): - """This module produces sinusoidal positional embeddings of any length.""" - - def __init__(self, num_positions: int, embedding_dim: int, **kwargs): - super().__init__(**kwargs) - - if embedding_dim % 2 != 0: - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - - self.embedding_dim = embedding_dim - self.num_positions = num_positions - - def build(self, input_shape: tf.TensorShape): - """ - Build shared token embedding layer Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - - weight = self._init_weight(self.num_positions, self.embedding_dim) - - self.weight = self.add_weight( - name="embeddings", - shape=[self.num_positions, self.embedding_dim], - ) - weight = tf.cast(weight, dtype=self.weight.dtype) - - self.weight.assign(weight) - - super().build(input_shape) - - @staticmethod - def _init_weight(n_pos: int, dim: int): - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. [dim // 2:] - """ - position_enc = np.array( - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] - ) - table = np.zeros_like(position_enc) - # index 0 is all zero - table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) - table[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) - # convert to tensor - table = tf.convert_to_tensor(table) - tf.stop_gradient(table) - return table - - def call( - self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return tf.gather(self.weight, position_ids) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian -class TFMarianAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartEncoderLayer with Bart->Marian -class TFMarianEncoderLayer(keras.layers.Layer): - def __init__(self, config: MarianConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFMarianAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None, - layer_head_mask: tf.Tensor | None, - training: bool | None = False, - ) -> tf.Tensor: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)` - """ - residual = hidden_states - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartDecoderLayer with Bart->Marian -class TFMarianDecoderLayer(keras.layers.Layer): - def __init__(self, config: MarianConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFMarianAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFMarianAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFMarianPreTrainedModel(TFPreTrainedModel): - config_class = MarianConfig - base_model_prefix = "model" - - -MARIAN_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`MarianConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MARIAN_GENERATION_EXAMPLE = r""" - TF version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available - models are listed [here](https://huggingface.co/models?search=Helsinki-NLP). - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFMarianMTModel - >>> from typing import List - - >>> src = "fr" # source language - >>> trg = "en" # target language - >>> sample_text = "où est l'arrêt de bus ?" - >>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}" - - >>> model = TFMarianMTModel.from_pretrained(model_name) - >>> tokenizer = AutoTokenizer.from_pretrained(model_name) - >>> batch = tokenizer([sample_text], return_tensors="tf") - >>> gen = model.generate(**batch) - >>> tokenizer.batch_decode(gen, skip_special_tokens=True) - "Where is the bus stop ?" - ``` -""" - -MARIAN_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFMarianEncoder(keras.layers.Layer): - config_class = MarianConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFMarianEncoderLayer`]. - - Args: - config: MarianConfig - """ - - def __init__(self, config: MarianConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFMarianSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFMarianEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ): - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFMarianDecoder(keras.layers.Layer): - config_class = MarianConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFMarianDecoderLayer`] - - Args: - config: MarianConfig - embed_tokens: output embedding - """ - - def __init__(self, config: MarianConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFMarianSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFMarianDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = self.dropout(hidden_states + positions, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFMarianMainLayer(keras.layers.Layer): - config_class = MarianConfig - - def __init__(self, config: MarianConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" - - self.encoder = TFMarianEncoder(config, self.shared, name="encoder") - self.decoder = TFMarianDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ): - if decoder_input_ids is None and decoder_inputs_embeds is None: - use_cache = False - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare MARIAN Model outputting raw hidden-states without any specific head on top.", - MARIAN_START_DOCSTRING, -) -class TFMarianModel(TFMarianPreTrainedModel): - def __init__(self, config: MarianConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFMarianMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> tuple[tf.Tensor] | TFSeq2SeqModelOutput: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The MARIAN Model with a language modeling head. Can be used for summarization.", - MARIAN_START_DOCSTRING, -) -class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFMarianMainLayer(config, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @unpack_inputs - @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(MARIAN_GENERATION_EXAMPLE) - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFSeq2SeqLMOutput: - r""" - labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.fill(shape_list(labels), tf.cast(-100, labels.dtype)), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"] diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py deleted file mode 100644 index ad269860a959..000000000000 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ /dev/null @@ -1,1780 +0,0 @@ -# coding=utf-8 -# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax MBart model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, - FlaxSeq2SeqQuestionAnsweringModelOutput, - FlaxSeq2SeqSequenceClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_mbart import MBartConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25" -_CONFIG_FOR_DOC = "MBartConfig" - - -MBART_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`MBartConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -MBART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -MBART_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -MBART_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not - have a single `decoder_start_token_id` in contrast to other Bart-like models. - """ - prev_output_tokens = jnp.array(input_ids).copy() - - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - - # replace possible -100 values in labels by `pad_token_id` - prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids) - index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1) - decoder_start_tokens = jnp.array( - [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=jnp.int32 - ).squeeze() - - prev_output_tokens = prev_output_tokens.at[:, 1:].set(prev_output_tokens[:, :-1]) - prev_output_tokens = prev_output_tokens.at[:, 0].set(decoder_start_tokens) - - return prev_output_tokens - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->MBart -class FlaxMBartAttention(nn.Module): - config: MBartConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class FlaxMBartEncoderLayer(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxMBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->MBart -class FlaxMBartEncoderLayerCollection(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxMBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxMBartDecoderLayer(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxMBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxMBartAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->MBart -class FlaxMBartDecoderLayerCollection(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxMBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartClassificationHead with Bart->MBart -class FlaxMBartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - config: MBartConfig - inner_dim: int - num_classes: int - pooler_dropout: float - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.dropout = nn.Dropout(rate=self.pooler_dropout) - self.out_proj = nn.Dense( - self.num_classes, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = jnp.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class FlaxMBartEncoder(nn.Module): - config: MBartConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(position_ids + self.offset) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxMBartDecoder(nn.Module): - config: MBartConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = self.embed_positions(position_ids + self.offset) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->MBart -class FlaxMBartModule(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxMBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): - config_class = MBartConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: MBartConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule - input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(MBART_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MBartConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration - - >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MBartConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration - - >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxMBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare MBart Model transformer outputting raw hidden-states without any specific head on top.", - MBART_START_DOCSTRING, -) -class FlaxMBartModel(FlaxMBartPreTrainedModel): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxMBartModule - - -append_call_sample_docstring(FlaxMBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->MBart -class FlaxMBartForConditionalGenerationModule(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The MMBart Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING -) -class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel): - module_class = FlaxMBartForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MBartConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration - - >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxMBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias.astype(self.dtype) - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r""" - Returns: - - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration, MBartConfig - - >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration - - >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> # de_DE is the language symbol id for German - >>> TXT = " Meine Freunde sind nett aber sie essen zu viel Kuchen. de_DE" - >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="np")["input_ids"] - - >>> logits = model(input_ids).logits - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item() - >>> probs = logits[0, masked_index].softmax(dim=0) - >>> values, predictions = probs.topk(5) - - >>> tokenizer.decode(predictions).split() - ``` -""" - -overwrite_call_docstring( - FlaxMBartForConditionalGeneration, MBART_INPUTS_DOCSTRING + FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxMBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForSequenceClassificationModule with Bart->MBart -class FlaxMBartForSequenceClassificationModule(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 - num_labels: Optional[int] = None - - def setup(self): - self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) - self.classification_head = FlaxMBartClassificationHead( - config=self.config, - inner_dim=self.config.d_model, - num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels, - pooler_dropout=self.config.classifier_dropout, - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] # last hidden state - - eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) - - # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation - if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer): - if len(jnp.unique(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - - if any(eos_mask.sum(1) == 0): - raise ValueError("There are missing tokens in input_ids") - - # Ensure to keep 1 only for the last token for each example - eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6 - eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0) - - sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1) - logits = self.classification_head(sentence_representation, deterministic=deterministic) - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return FlaxSeq2SeqSequenceClassifierOutput( - logits=logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - MBART_START_DOCSTRING, -) -class FlaxMBartForSequenceClassification(FlaxMBartPreTrainedModel): - module_class = FlaxMBartForSequenceClassificationModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxMBartForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForQuestionAnsweringModule with Bart->MBart -class FlaxMBartForQuestionAnsweringModule(nn.Module): - config: MBartConfig - dtype: jnp.dtype = jnp.float32 - num_labels = 2 - - def setup(self): - self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return output - - return FlaxSeq2SeqQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - MBART_START_DOCSTRING, -) -class FlaxMBartForQuestionAnswering(FlaxMBartPreTrainedModel): - module_class = FlaxMBartForQuestionAnsweringModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxMBartForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxMBartForConditionalGeneration", - "FlaxMBartForQuestionAnswering", - "FlaxMBartForSequenceClassification", - "FlaxMBartModel", - "FlaxMBartPreTrainedModel", -] diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py deleted file mode 100644 index ac29bfeac76f..000000000000 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ /dev/null @@ -1,1572 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 MBart model.""" - -from __future__ import annotations - -import random - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_mbart import MBartConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25" -_CONFIG_FOR_DOC = "MBartConfig" - - -LARGE_NEGATIVE = -1e8 - - -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int): - """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not - have a single `decoder_start_token_id` in contrast to other Bart-like models. - """ - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - input_ids = tf.where( - input_ids == -100, tf.fill(shape_list(input_ids), tf.cast(pad_token_id, input_ids.dtype)), input_ids - ) - language_id_index = ( - tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1 - ) - language_id_index = tf.stack( - [tf.range(shape_list(input_ids)[0], dtype=input_ids.dtype), language_id_index], axis=-1 - ) - languages_ids = tf.gather_nd(input_ids, language_id_index) - - shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartLearnedPositionalEmbedding with Bart->MBart -class TFMBartLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) - - def call( - self, - input_shape: tf.TensorShape | None = None, - past_key_values_length: int = 0, - position_ids: tf.Tensor | None = None, - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(seq_len, delta=1, name="range") - position_ids += past_key_values_length - - offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32 - return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype)) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart -class TFMBartAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFMBartEncoderLayer(keras.layers.Layer): - def __init__(self, config: MBartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFMBartAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - layer_head_mask: tf.Tensor, - training: bool | None = False, - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(encoder_attention_heads,)* - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFMBartDecoderLayer(keras.layers.Layer): - def __init__(self, config: MBartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFMBartAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFMBartAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape *(batch, seq_len, embed_dim)* - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(decoder_attention_heads,)* - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - *(decoder_attention_heads,)* - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFMBartPreTrainedModel(TFPreTrainedModel): - config_class = MBartConfig - base_model_prefix = "model" - - -MBART_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`MBartConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MBART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that - varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -MBART_GENERATION_EXAMPLE = r""" - Translation example: - - ```python - >>> from transformers import AutoTokenizer, TFMBartForConditionalGeneration - - >>> model = TFMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro") - - >>> example_english_phrase = "42 is the answer" - >>> inputs = tokenizer(example_english_phrase, return_tensors="tf") - - >>> # Translate - >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5) - >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - '42 este răspuns' - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, TFMBartForConditionalGeneration - >>> import tensorflow as tf - - >>> model = TFMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25") - - >>> # de_DE is the language symbol id for German - >>> TXT = " Meine Freunde sind nett aber sie essen zu viel Kuchen. de_DE" - - >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="tf")["input_ids"] - >>> logits = model(input_ids).logits - - >>> masked_index = tf.where(input_ids[0] == tokenizer.mask_token_id)[0, 0] - >>> probs = tf.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = tf.math.top_k(probs, 5) - - >>> tokenizer.decode(predictions).split() - ['nett', 'sehr', 'ganz', 'nicht', 'so'] - ``` -""" - - -@keras_serializable -class TFMBartEncoder(keras.layers.Layer): - config_class = MBartConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFMBartEncoderLayer`]. - - Args: - config: MBartConfig - """ - - def __init__(self, config: MBartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFMBartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFMBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - self.embed_dim = config.d_model - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.embed_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFMBartDecoder(keras.layers.Layer): - config_class = MBartConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFMBartDecoderLayer`] - - Args: - config: MBartConfig - embed_tokens: output embedding - """ - - def __init__(self, config: MBartConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFMBartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFMBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layernorm_embedding", None) is not None: - with tf.name_scope(self.layernorm_embedding.name): - self.layernorm_embedding.build([None, None, self.config.d_model]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFMBartMainLayer(keras.layers.Layer): - config_class = MBartConfig - - def __init__(self, config: MBartConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" - - self.encoder = TFMBartEncoder(config, self.shared, name="encoder") - self.decoder = TFMBartDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFSeq2SeqModelOutput | tf.Tensor: - if decoder_input_ids is None and decoder_inputs_embeds is None: - use_cache = False - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if decoder_input_ids is None and input_ids is not None: - decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare MBART Model outputting raw hidden-states without any specific head on top.", - MBART_START_DOCSTRING, -) -class TFMBartModel(TFMBartPreTrainedModel): - def __init__(self, config: MBartConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFMBartMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.", - MBART_START_DOCSTRING, -) -class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFMBartMainLayer(config, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @unpack_inputs - @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(MBART_GENERATION_EXAMPLE) - def call( - self, - input_ids: TFModelInputType = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: - """ - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"] diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py deleted file mode 100644 index 2c084ee114d7..000000000000 --- a/src/transformers/models/mistral/modeling_flax_mistral.py +++ /dev/null @@ -1,744 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Mistral AI and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Mistral model.""" - -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPast, - FlaxCausalLMOutput, - FlaxCausalLMOutputWithCrossAttentions, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward -from .configuration_mistral import MistralConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "MistralConfig" -_REAL_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" -_CHECKPOINT_FOR_DOC = "ksmcg/Mistral-tiny" - -MISTRAL_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`MistralConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or - `jax.numpy.bfloat16`. - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -MISTRAL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm with Llama->Mistral -class FlaxMistralRMSNorm(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.epsilon = self.config.rms_norm_eps - self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) - - def __call__(self, hidden_states): - variance = jnp.asarray(hidden_states, dtype=jnp.float32) - variance = jnp.power(variance, 2) - variance = variance.mean(-1, keepdims=True) - # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` - hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) - - return self.weight * jnp.asarray(hidden_states, dtype=self.dtype) - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Mistral -class FlaxMistralRotaryEmbedding(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - head_dim = self.config.hidden_size // self.config.num_attention_heads - self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) - - def __call__(self, key, query, position_ids): - sincos = self.sincos[position_ids] - sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) - - key = apply_rotary_pos_emb(key, sin_pos, cos_pos) - query = apply_rotary_pos_emb(query, sin_pos, cos_pos) - - key = jnp.asarray(key, dtype=self.dtype) - query = jnp.asarray(query, dtype=self.dtype) - - return key, query - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Mistral -class FlaxMistralMLP(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - embed_dim = self.config.hidden_size - inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim - - kernel_init = jax.nn.initializers.normal(self.config.initializer_range) - self.act = ACT2FN[self.config.hidden_act] - - self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) - - def __call__(self, hidden_states): - up_proj_states = self.up_proj(hidden_states) - gate_states = self.act(self.gate_proj(hidden_states)) - - hidden_states = self.down_proj(up_proj_states * gate_states) - return hidden_states - - -# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): - return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) - - -# Copied from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions -def create_sinusoidal_positions(num_pos, dim): - inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) - freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") - - emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) - return jnp.array(out[:, :, :num_pos]) - - -# Copied from transformers.models.llama.modeling_flax_llama.rotate_half -def rotate_half(tensor): - """Rotates half the hidden dims of the input.""" - rotate_half_tensor = jnp.concatenate( - (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 - ) - return rotate_half_tensor - - -class FlaxMistralAttention(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - config = self.config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 - self.rope_theta = config.rope_theta - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype) - self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype) - self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype) - self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype) - causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - self.causal_mask = jnp.triu(causal_mask, k=-(config.sliding_window or 0)) - self.rotary_emb = FlaxMistralRotaryEmbedding(self.config, dtype=self.dtype) - - def _split_heads(self, hidden_states, num_heads): - return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) - - @nn.compact - # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - deterministic: bool = True, - output_attentions: bool = False, - init_cache: bool = False, - ) -> tuple[jnp.ndarray, jnp.ndarray]: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states, self.num_heads) - key_states = self._split_heads(key_states, self.num_key_value_heads) - value_states = self._split_heads(value_states, self.num_key_value_heads) - - key_states, query_states = self.rotary_emb(key_states, query_states, position_ids) - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - - batch_size = hidden_states.shape[0] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - - if self.has_variable("cache", "cached_key") or init_cache: - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2) - value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2) - - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - - # usual dot product attention - attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - deterministic=deterministic, - dropout_rate=self.config.attention_dropout, - dtype=attention_dtype, - ) - - if self.attention_softmax_in_fp32: - attn_weights = attn_weights.astype(self.dtype) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.o_proj(attn_output) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Mistral -class FlaxMistralDecoderLayer(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.input_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) - self.self_attn = FlaxMistralAttention(self.config, dtype=self.dtype) - self.post_attention_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) - self.mlp = FlaxMistralMLP(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - ): - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - outputs = self.self_attn( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - # residual connection - attn_output = outputs[0] - hidden_states = residual + attn_output - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + hidden_states - - return (hidden_states,) + outputs[1:] - - -# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Mistral, GPT_NEO->MISTRAL, transformer->model -class FlaxMistralPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = MistralConfig - base_model_prefix = "model" - module_class: nn.Module = None - - def __init__( - self, - config: MistralConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length)) - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_ids.shape - - if position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxMistralAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - not train, - False, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Mistral -class FlaxMistralLayerCollection(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.blocks = [ - FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i)) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = False, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for block in self.blocks: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - # this contains possible `None` values - `FlaxMistralModule` will filter them out - outputs = (hidden_states, all_hidden_states, all_attentions) - - return outputs - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Mistral -class FlaxMistralModule(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.hidden_size = self.config.hidden_size - embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.hidden_size, - embedding_init=embedding_init, - dtype=self.dtype, - ) - self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype) - self.norm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic=True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - input_embeds = self.embed_tokens(input_ids.astype("i4")) - - outputs = self.layers( - input_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states = outputs[1] + (hidden_states,) - outputs = (hidden_states, all_hidden_states) + outputs[2:] - else: - outputs = (hidden_states,) + outputs[1:] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs[1], - attentions=outputs[-1], - ) - - -@add_start_docstrings( - "The bare Mistral Model transformer outputting raw hidden-states without any specific head on top.", - MISTRAL_START_DOCSTRING, -) -class FlaxMistralModel(FlaxMistralPreTrainedModel): - module_class = FlaxMistralModule - - -append_call_sample_docstring( - FlaxMistralModel, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPast, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - - -# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Mistral -class FlaxMistralForCausalLMModule(nn.Module): - config: MistralConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.model = FlaxMistralModule(self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - - def __call__( - self, - input_ids, - attention_mask=None, - position_ids=None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - outputs = self.model( - input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - -@add_start_docstrings( - """ - The Mistral Model transformer with a language modeling head (linear layer) on top. - """, - MISTRAL_START_DOCSTRING, -) - -# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Mistral -class FlaxMistralForCausalLM(FlaxMistralPreTrainedModel): - module_class = FlaxMistralForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since Mistral uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxMistralForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, - real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, -) - -__all__ = ["FlaxMistralForCausalLM", "FlaxMistralModel", "FlaxMistralPreTrainedModel"] diff --git a/src/transformers/models/mistral/modeling_tf_mistral.py b/src/transformers/models/mistral/modeling_tf_mistral.py deleted file mode 100644 index d3ca7d13b6a8..000000000000 --- a/src/transformers/models/mistral/modeling_tf_mistral.py +++ /dev/null @@ -1,1016 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Mistral AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Mistral model.""" - -import math -import warnings -from typing import Optional, Union - -import tensorflow as tf - -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPast, - TFCausalLMOutputWithPast, - TFSequenceClassifierOutputWithPast, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - get_tf_activation, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_mistral import MistralConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "MistralConfig" - - -def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0): - """ - Make causal mask used for bi-directional self-attention, supporting both static and dynamic shapes. - """ - bsz, tgt_len = input_ids_shape - - # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask) - mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min) - mask_cond = tf.range(tgt_len) - mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1) - - if bsz is None: - # When batch size is dynamic, expand and tile - # so we can compile a functional model - mask = tf.expand_dims(mask, 0) - mask = tf.expand_dims(mask, 0) # shape: (1, 1, tgt_len, tgt_len + past_key_values_length) - mask = tf.tile(mask, [bsz, 1, 1, 1]) - else: - # When batch size is static, directly use broadcast_to - mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) - - return mask - - -def _expand_mask(mask, dtype, tgt_len=None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = shape_list(mask) - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1) - expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len]) - - inverted_mask = 1.0 - tf.cast(expanded_mask, dtype) - - return tf.where( - tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask - ) - - -class TFMistralRMSNorm(keras.layers.Layer): - def __init__(self, hidden_size, eps=1e-6, **kwargs): - """ - TFMistralRMSNorm is equivalent to T5LayerNorm - """ - super().__init__(**kwargs) - self.hidden_size = hidden_size - self.variance_epsilon = eps - - def build(self, input_shape=None): - self.weight = self.add_weight( - name="weight", - shape=self.hidden_size, - initializer="ones", - ) - if self.built: - return - self.built = True - - def call(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = tf.cast(hidden_states, tf.float32) - variance = tf.reduce_mean(tf.square(hidden_states), axis=-1, keepdims=True) - hidden_states = tf.divide(hidden_states, tf.sqrt(variance + self.variance_epsilon)) - return self.weight * tf.cast(hidden_states, input_dtype) - - -# Verification: https://colab.research.google.com/gist/ariG23498/f8d8131b795a131b93d99e70ee93c192/scratchpad.ipynb -class TFMistralRotaryEmbedding(keras.layers.Layer): - def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs): - super().__init__(**kwargs) - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.inv_freq = 1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) - - def call(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - t = tf.cast(tf.range(seq_len, dtype=tf.int64), self.inv_freq.dtype) - freqs = tf.einsum("i,j->ij", t, self.inv_freq) - emb = tf.concat([freqs, freqs], axis=-1) - cos_values = tf.cast(tf.cos(emb), x.dtype) - sin_values = tf.cast(tf.sin(emb), x.dtype) - - cos_values = cos_values[:seq_len] - cos_values = tf.cast(cos_values, dtype=x.dtype) - sin_values = sin_values[:seq_len] - sin_values = tf.cast(sin_values, dtype=x.dtype) - return (cos_values, sin_values) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - mid_length = shape_list(x)[-1] // 2 - x1 = x[..., :mid_length] - x2 = x[..., mid_length:] - return tf.concat([-x2, x1], axis=-1) - - -# Verification: https://colab.research.google.com/gist/ariG23498/bb8474baeb33f4ae6ed7d77da5f7e7a4/scratchpad.ipynb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`tf.Tensor`): The query tensor. - k (`tf.Tensor`): The key tensor. - cos (`tf.Tensor`): The cosine part of the rotary embedding. - sin (`tf.Tensor`): The sine part of the rotary embedding. - position_ids (`tf.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(tf.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = tf.expand_dims(tf.gather(cos, position_ids), unsqueeze_dim) - sin = tf.expand_dims(tf.gather(sin, position_ids), unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class TFMistralMLP(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = keras.layers.Dense(self.intermediate_size, use_bias=False, name="gate_proj") - self.up_proj = keras.layers.Dense(self.intermediate_size, use_bias=False, name="up_proj") - self.down_proj = keras.layers.Dense(self.hidden_size, use_bias=False, name="down_proj") - self.act_fn = get_tf_activation(config.hidden_act) - - def call(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "gate_proj", None) is not None: - with tf.name_scope(self.gate_proj.name): - self.gate_proj.build((self.hidden_size,)) - if getattr(self, "up_proj", None) is not None: - with tf.name_scope(self.up_proj.name): - self.up_proj.build((self.hidden_size,)) - if getattr(self, "down_proj", None) is not None: - with tf.name_scope(self.down_proj.name): - self.down_proj.build((self.intermediate_size,)) - - -# Verification: https://colab.research.google.com/gist/ariG23498/556d443d491966763ce2e7eee336efed/scratchpad.ipynb -def repeat_kv(hidden_states: tf.Tensor, n_rep: int) -> tf.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = shape_list(hidden_states) - if n_rep == 1: - return hidden_states - hidden_states = tf.expand_dims(hidden_states, 2) - hidden_states = tf.repeat(hidden_states, repeats=n_rep, axis=2) - return tf.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) - - -class TFMistralAttention(keras.layers.Layer): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = keras.layers.Dense(self.num_heads * self.head_dim, use_bias=False, name="q_proj") - self.k_proj = keras.layers.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, name="k_proj") - self.v_proj = keras.layers.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, name="v_proj") - self.o_proj = keras.layers.Dense(self.hidden_size, use_bias=False, name="o_proj") - - self.rotary_emb = TFMistralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - name="rotary_emb", - ) - self.dropout = keras.layers.Dropout(rate=self.attention_dropout) - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - tensor = tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) - tensor = tf.transpose(tensor, perm=(0, 2, 1, 3)) - return tensor - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_value: Optional[tuple[tf.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - training=None, - **kwargs, - ) -> tuple[tf.Tensor, Optional[tf.Tensor], Optional[tuple[tf.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = shape_list(hidden_states) - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = tf.transpose( - tf.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)), perm=(0, 2, 1, 3) - ) - key_states = tf.transpose( - tf.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)), perm=(0, 2, 1, 3) - ) - value_states = tf.transpose( - tf.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)), perm=(0, 2, 1, 3) - ) - - kv_seq_len = shape_list(key_states)[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb( - x=value_states, - seq_len=kv_seq_len, - ) - query_states, key_states = apply_rotary_pos_emb( - q=query_states, - k=key_states, - cos=cos, - sin=sin, - position_ids=position_ids, - ) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = stable_softmax(attn_weights, axis=-1) - attn_weights = tf.cast(attn_weights, query_states.dtype) - attn_weights = self.dropout( - attn_weights, - training=training, - ) - attn_output = tf.matmul(attn_weights, value_states) - - attn_output = tf.transpose(attn_output, perm=(0, 2, 1, 3)) - attn_output = tf.reshape(attn_output, (bsz, q_len, self.hidden_size)) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build((self.hidden_size,)) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build((self.hidden_size,)) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build((self.hidden_size,)) - if getattr(self, "o_proj", None) is not None: - with tf.name_scope(self.o_proj.name): - self.o_proj.build((self.num_heads * self.head_dim,)) - - -class TFMistralDecoderLayer(keras.layers.Layer): - def __init__(self, config: MistralConfig, layer_idx: int, **kwargs): - super().__init__(**kwargs) - self.hidden_size = config.hidden_size - - self.self_attn = TFMistralAttention(config, layer_idx, name="self_attn") - - self.mlp = TFMistralMLP(config, name="mlp") - self.input_layernorm = TFMistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm") - self.post_attention_layernorm = TFMistralRMSNorm( - config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm" - ) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_value: Optional[tuple[tf.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> tuple[tf.Tensor, Optional[tuple[tf.Tensor, tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "input_layernorm", None) is not None: - with tf.name_scope(self.input_layernorm.name): - self.input_layernorm.build(None) - if getattr(self, "post_attention_layernorm", None) is not None: - with tf.name_scope(self.post_attention_layernorm.name): - self.post_attention_layernorm.build(None) - - -@keras_serializable -class TFMistralMainLayer(keras.layers.Layer): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] - - Args: - config: MistralConfig - """ - - config_class = MistralConfig - - def __init__(self, config: MistralConfig, **kwargs): - super().__init__(**kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.hidden_size = config.hidden_size - - # TF and PT Embedding check: https://colab.research.google.com/gist/ariG23498/2b9826818875c9c4968c79cb19f55f2c/scratchpad.ipynb - self.embed_tokens = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.hidden_size, - name="embed_tokens", - ) - self.layers = [ - TFMistralDecoderLayer(config, layer_idx, name=f"layers.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ] - self._attn_implementation = config._attn_implementation - self.norm = TFMistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm") - self.config = config - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - # if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @unpack_inputs - def call( - self, - input_ids: Optional[tf.Tensor] = None, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_values: Optional[list[tf.Tensor]] = None, - inputs_embeds: Optional[tf.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, TFBaseModelOutputWithPast]: - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = shape_list(input_ids) - elif inputs_embeds is not None: - batch_size, seq_length, _ = shape_list(inputs_embeds) - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = shape_list(past_key_values[0][0])[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - position_ids = tf.range( - start=past_key_values_length, limit=seq_length + past_key_values_length, dtype=tf.int64 - ) - position_ids = tf.reshape(tf.expand_dims(position_ids, 0), (-1, seq_length)) - - else: - position_ids = tf.cast(tf.reshape(position_ids, (-1, seq_length)), tf.int64) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is None: - attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return TFBaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -MISTRAL_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `model` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`MistralConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Mistral Model outputting raw hidden-states without any specific head on top.", - MISTRAL_START_DOCSTRING, -) -class TFMistralPreTrainedModel(TFPreTrainedModel): - config_class = MistralConfig - base_model_prefix = "model" - - -MISTRAL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(tf.Tensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - One formats is allowed: - - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Mistral Model outputting raw hidden-states without any specific head on top.", - MISTRAL_START_DOCSTRING, -) -class TFMistralModel(TFMistralPreTrainedModel): - def __init__(self, config: MistralConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFMistralMainLayer(config, name="model") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - def call( - self, - input_ids: Optional[tf.Tensor] = None, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_values: Optional[list[tf.Tensor]] = None, - inputs_embeds: Optional[tf.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, TFBaseModelOutputWithPast]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -class TFMistralForCausalLM(TFMistralPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFMistralMainLayer(config, name="model") - self.vocab_size = config.vocab_size - self.lm_head = keras.layers.Dense( - config.vocab_size, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="lm_head", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def call( - self, - input_ids: Optional[tf.Tensor] = None, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_values: Optional[list[tf.Tensor]] = None, - inputs_embeds: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, TFCausalLMOutputWithPast]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` - or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = tf.cast(logits, tf.float32) - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values: - input_ids = tf.expand_dims(input_ids[:, -1], -1) - - position_ids = kwargs.get("position_ids") - if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) - if past_key_values: - position_ids = tf.expand_dims(position_ids[:, -1], -1) - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build((self.config.hidden_size,)) - - -@add_start_docstrings( - """ - The Mistral Model transformer with a sequence classification head on top (linear layer). - - [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - MISTRAL_START_DOCSTRING, -) -class TFMistralForSequenceClassification(TFMistralPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.model = TFMistralMainLayer(config, name="model") - self.score = keras.layers.Dense( - self.num_labels, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="score", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def call( - self, - input_ids: Optional[tf.Tensor] = None, - attention_mask: Optional[tf.Tensor] = None, - position_ids: Optional[tf.Tensor] = None, - past_key_values: Optional[list[tf.Tensor]] = None, - inputs_embeds: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, TFSequenceClassifierOutputWithPast]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ - - transformer_outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "score", None) is not None: - with tf.name_scope(self.score.name): - self.score.build((self.config.hidden_size,)) - - -__all__ = ["TFMistralModel", "TFMistralForCausalLM", "TFMistralForSequenceClassification", "TFMistralPreTrainedModel"] diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 022a9d036cdb..000000000000 --- a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import torch - -from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): - # Initialise PyTorch model - config = MobileBertConfig.from_json_file(mobilebert_config_file) - print(f"Building PyTorch model from configuration: {config}") - model = MobileBertForPreTraining(config) - # Load weights from tf checkpoint - model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path) - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--mobilebert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained MobileBERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py deleted file mode 100644 index e4d148aa76c5..000000000000 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ /dev/null @@ -1,1979 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 MobileBERT model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFNextSentencePredictorOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFNextSentencePredictionLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_mobilebert import MobileBertConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased" -_CONFIG_FOR_DOC = "MobileBertConfig" - -# TokenClassification docstring -_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "vumichien/mobilebert-finetuned-ner" -_TOKEN_CLASS_EXPECTED_OUTPUT = "['I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']" -_TOKEN_CLASS_EXPECTED_LOSS = 0.03 - -# QuestionAnswering docstring -_CHECKPOINT_FOR_QA = "vumichien/mobilebert-uncased-squad-v2" -_QA_EXPECTED_OUTPUT = "'a nice puppet'" -_QA_EXPECTED_LOSS = 3.98 -_QA_TARGET_START_INDEX = 12 -_QA_TARGET_END_INDEX = 13 - -# SequenceClassification docstring -_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "vumichien/emo-mobilebert" -_SEQ_CLASS_EXPECTED_OUTPUT = "'others'" -_SEQ_CLASS_EXPECTED_LOSS = "4.72" - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainingLoss -class TFMobileBertPreTrainingLoss: - """ - Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining - NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss - computation. - """ - - def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0]) - # make sure only labels that are not equal to -100 - # are taken into account for the loss computation - lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype) - masked_lm_losses = unmasked_lm_losses * lm_loss_mask - reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask) - - # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway - unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1]) - ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype) - masked_ns_loss = unmasked_ns_loss * ns_loss_mask - - reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask) - - return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,)) - - -class TFMobileBertIntermediate(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense(config.intermediate_size, name="dense") - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.true_hidden_size]) - - -class TFLayerNorm(keras.layers.LayerNormalization): - def __init__(self, feat_size, *args, **kwargs): - self.feat_size = feat_size - super().__init__(*args, **kwargs) - - def build(self, input_shape=None): - super().build([None, None, self.feat_size]) - - -class TFNoNorm(keras.layers.Layer): - def __init__(self, feat_size, epsilon=None, **kwargs): - super().__init__(**kwargs) - self.feat_size = feat_size - - def build(self, input_shape): - self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros") - self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones") - super().build(input_shape) - - def call(self, inputs: tf.Tensor): - return inputs * self.weight + self.bias - - -NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm} - - -class TFMobileBertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.trigram_input = config.trigram_input - self.embedding_size = config.embedding_size - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.embedding_transformation = keras.layers.Dense(config.hidden_size, name="embedding_transformation") - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = NORM2FN[config.normalization_type]( - config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.embedded_input_size = self.embedding_size * (3 if self.trigram_input else 1) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "embedding_transformation", None) is not None: - with tf.name_scope(self.embedding_transformation.name): - self.embedding_transformation.build([None, None, self.embedded_input_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if self.trigram_input: - # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited - # Devices (https://huggingface.co/papers/2004.02984) - # - # The embedding table in BERT models accounts for a substantial proportion of model size. To compress - # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT. - # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512 - # dimensional output. - inputs_embeds = tf.concat( - [ - tf.pad(inputs_embeds[:, 1:], ((0, 0), (0, 1), (0, 0))), - inputs_embeds, - tf.pad(inputs_embeds[:, :-1], ((0, 0), (1, 0), (0, 0))), - ], - axis=2, - ) - - if self.trigram_input or self.embedding_size != self.hidden_size: - inputs_embeds = self.embedding_transformation(inputs_embeds) - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFMobileBertSelfAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - - self.num_attention_heads = config.num_attention_heads - self.output_attentions = config.output_attentions - assert config.hidden_size % config.num_attention_heads == 0 - self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, x, batch_size): - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call( - self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False - ): - batch_size = shape_list(attention_mask)[0] - mixed_query_layer = self.query(query_tensor) - mixed_key_layer = self.key(key_tensor) - mixed_value_layer = self.value(value_tensor) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul( - query_layer, key_layer, transpose_b=True - ) # (batch size, num_heads, seq_len_q, seq_len_k) - dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype) # scale attention_scores - attention_scores = attention_scores / tf.math.sqrt(dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFMobileBertModel call() function) - attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = tf.matmul(attention_probs, value_layer) - - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - context_layer = tf.reshape( - context_layer, (batch_size, -1, self.all_head_size) - ) # (batch_size, seq_len_q, all_head_size) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.true_hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.true_hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build( - [ - None, - None, - self.config.true_hidden_size - if self.config.use_bottleneck_attention - else self.config.hidden_size, - ] - ) - - -class TFMobileBertSelfOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.use_bottleneck = config.use_bottleneck - self.dense = keras.layers.Dense( - config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = NORM2FN[config.normalization_type]( - config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - if not self.use_bottleneck: - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, residual_tensor, training=False): - hidden_states = self.dense(hidden_states) - if not self.use_bottleneck: - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + residual_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.true_hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - -class TFMobileBertAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.self = TFMobileBertSelfAttention(config, name="self") - self.mobilebert_output = TFMobileBertSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - query_tensor, - key_tensor, - value_tensor, - layer_input, - attention_mask, - head_mask, - output_attentions, - training=False, - ): - self_outputs = self.self( - query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training - ) - - attention_output = self.mobilebert_output(self_outputs[0], layer_input, training=training) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "mobilebert_output", None) is not None: - with tf.name_scope(self.mobilebert_output.name): - self.mobilebert_output.build(None) - - -class TFOutputBottleneck(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.hidden_size, name="dense") - self.LayerNorm = NORM2FN[config.normalization_type]( - config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states, residual_tensor, training=False): - layer_outputs = self.dense(hidden_states) - layer_outputs = self.dropout(layer_outputs, training=training) - layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) - return layer_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.true_hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - -class TFMobileBertOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.use_bottleneck = config.use_bottleneck - self.dense = keras.layers.Dense( - config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = NORM2FN[config.normalization_type]( - config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - if not self.use_bottleneck: - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - else: - self.bottleneck = TFOutputBottleneck(config, name="bottleneck") - self.config = config - - def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False): - hidden_states = self.dense(hidden_states) - if not self.use_bottleneck: - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) - else: - hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) - hidden_states = self.bottleneck(hidden_states, residual_tensor_2) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - if getattr(self, "bottleneck", None) is not None: - with tf.name_scope(self.bottleneck.name): - self.bottleneck.build(None) - - -class TFBottleneckLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.intra_bottleneck_size, name="dense") - self.LayerNorm = NORM2FN[config.normalization_type]( - config.intra_bottleneck_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - self.config = config - - def call(self, inputs): - hidden_states = self.dense(inputs) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - -class TFBottleneck(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.key_query_shared_bottleneck = config.key_query_shared_bottleneck - self.use_bottleneck_attention = config.use_bottleneck_attention - self.bottleneck_input = TFBottleneckLayer(config, name="input") - if self.key_query_shared_bottleneck: - self.attention = TFBottleneckLayer(config, name="attention") - - def call(self, hidden_states): - # This method can return three different tuples of values. These different values make use of bottlenecks, - # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory - # usage. These linear layer have weights that are learned during training. - # - # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the - # key, query, value, and "layer input" to be used by the attention layer. - # This bottleneck is used to project the hidden. This last layer input will be used as a residual tensor - # in the attention self output, after the attention scores have been computed. - # - # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return - # four values, three of which have been passed through a bottleneck: the query and key, passed through the same - # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck. - # - # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck, - # and the residual layer will be this value passed through a bottleneck. - - bottlenecked_hidden_states = self.bottleneck_input(hidden_states) - if self.use_bottleneck_attention: - return (bottlenecked_hidden_states,) * 4 - elif self.key_query_shared_bottleneck: - shared_attention_input = self.attention(hidden_states) - return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states) - else: - return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "bottleneck_input", None) is not None: - with tf.name_scope(self.bottleneck_input.name): - self.bottleneck_input.build(None) - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - - -class TFFFNOutput(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(config.true_hidden_size, name="dense") - self.LayerNorm = NORM2FN[config.normalization_type]( - config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" - ) - self.config = config - - def call(self, hidden_states, residual_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.LayerNorm(hidden_states + residual_tensor) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - -class TFFFNLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.intermediate = TFMobileBertIntermediate(config, name="intermediate") - self.mobilebert_output = TFFFNOutput(config, name="output") - - def call(self, hidden_states): - intermediate_output = self.intermediate(hidden_states) - layer_outputs = self.mobilebert_output(intermediate_output, hidden_states) - return layer_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "mobilebert_output", None) is not None: - with tf.name_scope(self.mobilebert_output.name): - self.mobilebert_output.build(None) - - -class TFMobileBertLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.use_bottleneck = config.use_bottleneck - self.num_feedforward_networks = config.num_feedforward_networks - self.attention = TFMobileBertAttention(config, name="attention") - self.intermediate = TFMobileBertIntermediate(config, name="intermediate") - self.mobilebert_output = TFMobileBertOutput(config, name="output") - - if self.use_bottleneck: - self.bottleneck = TFBottleneck(config, name="bottleneck") - if config.num_feedforward_networks > 1: - self.ffn = [TFFFNLayer(config, name=f"ffn.{i}") for i in range(config.num_feedforward_networks - 1)] - - def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False): - if self.use_bottleneck: - query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) - else: - query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4 - - attention_outputs = self.attention( - query_tensor, - key_tensor, - value_tensor, - layer_input, - attention_mask, - head_mask, - output_attentions, - training=training, - ) - - attention_output = attention_outputs[0] - s = (attention_output,) - - if self.num_feedforward_networks != 1: - for i, ffn_module in enumerate(self.ffn): - attention_output = ffn_module(attention_output) - s += (attention_output,) - - intermediate_output = self.intermediate(attention_output) - layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training) - - outputs = ( - (layer_output,) - + attention_outputs[1:] - + ( - tf.constant(0), - query_tensor, - key_tensor, - value_tensor, - layer_input, - attention_output, - intermediate_output, - ) - + s - ) # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "mobilebert_output", None) is not None: - with tf.name_scope(self.mobilebert_output.name): - self.mobilebert_output.build(None) - if getattr(self, "bottleneck", None) is not None: - with tf.name_scope(self.bottleneck.name): - self.bottleneck.build(None) - if getattr(self, "ffn", None) is not None: - for layer in self.ffn: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFMobileBertEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layer = [TFMobileBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states, - attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=False, - ): - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], output_attentions, training=training - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFMobileBertPooler(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.do_activate = config.classifier_activation - if self.do_activate: - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - if not self.do_activate: - return first_token_tensor - else: - pooled_output = self.dense(first_token_tensor) - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFMobileBertPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build(None) - - -class TFMobileBertLMPredictionHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.transform = TFMobileBertPredictionHeadTransform(config, name="transform") - self.config = config - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - self.dense = self.add_weight( - shape=(self.config.hidden_size - self.config.embedding_size, self.config.vocab_size), - initializer="zeros", - trainable=True, - name="dense/weight", - ) - self.decoder = self.add_weight( - shape=(self.config.vocab_size, self.config.embedding_size), - initializer="zeros", - trainable=True, - name="decoder/weight", - ) - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self): - return self - - def set_output_embeddings(self, value): - self.decoder = value - self.config.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0)) - hidden_states = hidden_states + self.bias - return hidden_states - - -class TFMobileBertMLMHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.predictions = TFMobileBertLMPredictionHead(config, name="predictions") - - def call(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@keras_serializable -class TFMobileBertMainLayer(keras.layers.Layer): - config_class = MobileBertConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.num_hidden_layers = config.num_hidden_layers - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - - self.embeddings = TFMobileBertEmbeddings(config, name="embeddings") - self.encoder = TFMobileBertEncoder(config, name="encoder") - self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - if token_type_ids is None: - token_type_ids = tf.fill(input_shape, 0) - - embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - - encoder_outputs = self.encoder( - embedding_output, - extended_attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFMobileBertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = MobileBertConfig - base_model_prefix = "mobilebert" - - -@dataclass -class TFMobileBertForPreTrainingOutput(ModelOutput): - """ - Output type of [`TFMobileBertForPreTraining`]. - - Args: - prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - prediction_logits: tf.Tensor | None = None - seq_relationship_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -MOBILEBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`MobileBertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MOBILEBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.", - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertModel(TFMobileBertPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPooling: - outputs = self.mobilebert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - - -@add_start_docstrings( - """ - MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a - `next sentence prediction (classification)` head. - """, - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTrainingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - self.predictions = TFMobileBertMLMHead(config, name="predictions___cls") - self.seq_relationship = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls") - - def get_lm_head(self): - return self.predictions.predictions - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - next_sentence_label: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFMobileBertForPreTrainingOutput: - r""" - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFMobileBertForPreTraining - - >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased") - >>> model = TFMobileBertForPreTraining.from_pretrained("google/mobilebert-uncased") - >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 - >>> outputs = model(input_ids) - >>> prediction_scores, seq_relationship_scores = outputs[:2] - ```""" - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - - total_loss = None - if labels is not None and next_sentence_label is not None: - d_labels = {"labels": labels} - d_labels["next_sentence_label"] = next_sentence_label - total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score)) - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return TFMobileBertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - if getattr(self, "seq_relationship", None) is not None: - with tf.name_scope(self.seq_relationship.name): - self.seq_relationship.build(None) - - def tf_to_pt_weight_rename(self, tf_weight): - if tf_weight == "cls.predictions.decoder.weight": - return tf_weight, "mobilebert.embeddings.word_embeddings.weight" - else: - return (tf_weight,) - - -@add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING) -class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"seq_relationship___cls", - r"cls.seq_relationship", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") - self.predictions = TFMobileBertMLMHead(config, name="predictions___cls") - - def get_lm_head(self): - return self.predictions.predictions - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'paris'", - expected_loss=0.57, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFMaskedLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels - """ - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.predictions(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - def tf_to_pt_weight_rename(self, tf_weight): - if tf_weight == "cls.predictions.decoder.weight": - return tf_weight, "mobilebert.embeddings.word_embeddings.weight" - else: - return (tf_weight,) - - -class TFMobileBertOnlyNSPHead(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.seq_relationship = keras.layers.Dense(2, name="seq_relationship") - self.config = config - - def call(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "seq_relationship", None) is not None: - with tf.name_scope(self.seq_relationship.name): - self.seq_relationship.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """MobileBert Model with a `next sentence prediction (classification)` head on top.""", - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"predictions___cls", r"cls.predictions"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - self.cls = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - next_sentence_label: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFNextSentencePredictorOutput: - r""" - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFMobileBertForNextSentencePrediction - - >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased") - >>> model = TFMobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf") - - >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0] - ```""" - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - seq_relationship_scores = self.cls(pooled_output) - - next_sentence_loss = ( - None - if next_sentence_label is None - else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) - ) - - if not return_dict: - output = (seq_relationship_scores,) + outputs[2:] - return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output - - return TFNextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "cls", None) is not None: - with tf.name_scope(self.cls.name): - self.cls.build(None) - - -@add_start_docstrings( - """ - MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"predictions___cls", - r"seq_relationship___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, - expected_loss=_SEQ_CLASS_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFSequenceClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"predictions___cls", - r"seq_relationship___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_QA, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - qa_target_start_index=_QA_TARGET_START_INDEX, - qa_target_end_index=_QA_TARGET_END_INDEX, - expected_output=_QA_EXPECTED_OUTPUT, - expected_loss=_QA_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFQuestionAnsweringModelOutput: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions, "end_position": end_positions} - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MobileBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"predictions___cls", - r"seq_relationship___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFMultipleChoiceModelOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.mobilebert( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MobileBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - MOBILEBERT_START_DOCSTRING, -) -class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [ - r"pooler", - r"predictions___cls", - r"seq_relationship___cls", - r"cls.predictions", - r"cls.seq_relationship", - ] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, - expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFTokenClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.mobilebert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilebert", None) is not None: - with tf.name_scope(self.mobilebert.name): - self.mobilebert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFMobileBertForMaskedLM", - "TFMobileBertForMultipleChoice", - "TFMobileBertForNextSentencePrediction", - "TFMobileBertForPreTraining", - "TFMobileBertForQuestionAnswering", - "TFMobileBertForSequenceClassification", - "TFMobileBertForTokenClassification", - "TFMobileBertMainLayer", - "TFMobileBertModel", - "TFMobileBertPreTrainedModel", -] diff --git a/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 1b53bbeab475..000000000000 --- a/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,141 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert MobileNetV1 checkpoints from the tensorflow/models library.""" - -import argparse -import json -import re -from pathlib import Path - -import requests -import torch -from huggingface_hub import hf_hub_download -from PIL import Image - -from transformers import ( - MobileNetV1Config, - MobileNetV1ForImageClassification, - MobileNetV1ImageProcessor, - load_tf_weights_in_mobilenet_v1, -) -from transformers.utils import logging - - -logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -def get_mobilenet_v1_config(model_name): - config = MobileNetV1Config(layer_norm_eps=0.001) - - if "_quant" in model_name: - raise ValueError("Quantized models are not supported.") - - matches = re.match(r"^mobilenet_v1_([^_]*)_([^_]*)$", model_name) - if matches: - config.depth_multiplier = float(matches[1]) - config.image_size = int(matches[2]) - - # The TensorFlow version of MobileNetV1 predicts 1001 classes instead of - # the usual 1000. The first class (index 0) is "background". - config.num_labels = 1001 - filename = "imagenet-1k-id2label.json" - repo_id = "huggingface/label-files" - id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) - id2label = {int(k) + 1: v for k, v in id2label.items()} - id2label[0] = "background" - config.id2label = id2label - config.label2id = {v: k for k, v in id2label.items()} - - return config - - -# We will verify our results on an image of cute cats -def prepare_img(): - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - im = Image.open(requests.get(url, stream=True).raw) - return im - - -@torch.no_grad() -def convert_movilevit_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False): - """ - Copy/paste/tweak model's weights to our MobileNetV1 structure. - """ - config = get_mobilenet_v1_config(model_name) - - # Load 🤗 model - model = MobileNetV1ForImageClassification(config).eval() - - # Load weights from TensorFlow checkpoint - load_tf_weights_in_mobilenet_v1(model, config, checkpoint_path) - - # Check outputs on an image, prepared by MobileNetV1ImageProcessor - image_processor = MobileNetV1ImageProcessor( - crop_size={"width": config.image_size, "height": config.image_size}, - size={"shortest_edge": config.image_size + 32}, - ) - encoding = image_processor(images=prepare_img(), return_tensors="pt") - outputs = model(**encoding) - logits = outputs.logits - - assert logits.shape == (1, 1001) - - if model_name == "mobilenet_v1_1.0_224": - expected_logits = torch.tensor([-4.1739, -1.1233, 3.1205]) - elif model_name == "mobilenet_v1_0.75_192": - expected_logits = torch.tensor([-3.9440, -2.3141, -0.3333]) - else: - expected_logits = None - - if expected_logits is not None: - assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4) - - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - print(f"Saving model {model_name} to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) - print(f"Saving image processor to {pytorch_dump_folder_path}") - image_processor.save_pretrained(pytorch_dump_folder_path) - - if push_to_hub: - print("Pushing to the hub...") - repo_id = "google/" + model_name - image_processor.push_to_hub(repo_id) - model.push_to_hub(repo_id) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--model_name", - default="mobilenet_v1_1.0_224", - type=str, - help="Name of the MobileNetV1 model you'd like to convert. Should in the form 'mobilenet_v1__'.", - ) - parser.add_argument( - "--checkpoint_path", required=True, type=str, help="Path to the original TensorFlow checkpoint (.ckpt file)." - ) - parser.add_argument( - "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model directory." - ) - parser.add_argument( - "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." - ) - - args = parser.parse_args() - convert_movilevit_checkpoint( - args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub - ) diff --git a/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 1fdb9783ccf0..000000000000 --- a/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,177 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert MobileNetV2 checkpoints from the tensorflow/models library.""" - -import argparse -import json -import re -from pathlib import Path - -import requests -import torch -from huggingface_hub import hf_hub_download -from PIL import Image - -from transformers import ( - MobileNetV2Config, - MobileNetV2ForImageClassification, - MobileNetV2ForSemanticSegmentation, - MobileNetV2ImageProcessor, - load_tf_weights_in_mobilenet_v2, -) -from transformers.utils import logging - - -logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -def get_mobilenet_v2_config(model_name): - config = MobileNetV2Config(layer_norm_eps=0.001) - - if "quant" in model_name: - raise ValueError("Quantized models are not supported.") - - matches = re.match(r"^.*mobilenet_v2_([^_]*)_([^_]*)$", model_name) - if matches: - config.depth_multiplier = float(matches[1]) - config.image_size = int(matches[2]) - - if model_name.startswith("deeplabv3_"): - config.output_stride = 8 - config.num_labels = 21 - filename = "pascal-voc-id2label.json" - else: - # The TensorFlow version of MobileNetV2 predicts 1001 classes instead - # of the usual 1000. The first class (index 0) is "background". - config.num_labels = 1001 - filename = "imagenet-1k-id2label.json" - - repo_id = "huggingface/label-files" - id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) - - if config.num_labels == 1001: - id2label = {int(k) + 1: v for k, v in id2label.items()} - id2label[0] = "background" - else: - id2label = {int(k): v for k, v in id2label.items()} - - config.id2label = id2label - config.label2id = {v: k for k, v in id2label.items()} - - return config - - -# We will verify our results on an image of cute cats -def prepare_img(): - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - im = Image.open(requests.get(url, stream=True).raw) - return im - - -@torch.no_grad() -def convert_movilevit_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False): - """ - Copy/paste/tweak model's weights to our MobileNetV2 structure. - """ - config = get_mobilenet_v2_config(model_name) - - # Load 🤗 model - if model_name.startswith("deeplabv3_"): - model = MobileNetV2ForSemanticSegmentation(config).eval() - else: - model = MobileNetV2ForImageClassification(config).eval() - - # Load weights from TensorFlow checkpoint - load_tf_weights_in_mobilenet_v2(model, config, checkpoint_path) - - # Check outputs on an image, prepared by MobileNetV2ImageProcessor - image_processor = MobileNetV2ImageProcessor( - crop_size={"width": config.image_size, "height": config.image_size}, - size={"shortest_edge": config.image_size + 32}, - ) - encoding = image_processor(images=prepare_img(), return_tensors="pt") - outputs = model(**encoding) - logits = outputs.logits - - if model_name.startswith("deeplabv3_"): - assert logits.shape == (1, 21, 65, 65) - - if model_name == "deeplabv3_mobilenet_v2_1.0_513": - expected_logits = torch.tensor( - [ - [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]], - [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]], - [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]], - ] - ) - - else: - raise ValueError(f"Unknown model name: {model_name}") - - assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-4) - else: - assert logits.shape == (1, 1001) - - if model_name == "mobilenet_v2_1.4_224": - expected_logits = torch.tensor([0.0181, -1.0015, 0.4688]) - elif model_name == "mobilenet_v2_1.0_224": - expected_logits = torch.tensor([0.2445, -1.1993, 0.1905]) - elif model_name == "mobilenet_v2_0.75_160": - expected_logits = torch.tensor([0.2482, 0.4136, 0.6669]) - elif model_name == "mobilenet_v2_0.35_96": - expected_logits = torch.tensor([0.1451, -0.4624, 0.7192]) - else: - expected_logits = None - - if expected_logits is not None: - assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4) - - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - print(f"Saving model {model_name} to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) - print(f"Saving image processor to {pytorch_dump_folder_path}") - image_processor.save_pretrained(pytorch_dump_folder_path) - - if push_to_hub: - print("Pushing to the hub...") - repo_id = "google/" + model_name - image_processor.push_to_hub(repo_id) - model.push_to_hub(repo_id) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--model_name", - default="mobilenet_v2_1.0_224", - type=str, - help="Name of the MobileNetV2 model you'd like to convert. Should in the form 'mobilenet_v2__'.", - ) - parser.add_argument( - "--checkpoint_path", required=True, type=str, help="Path to the original TensorFlow checkpoint (.ckpt file)." - ) - parser.add_argument( - "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model directory." - ) - parser.add_argument( - "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." - ) - - args = parser.parse_args() - convert_movilevit_checkpoint( - args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub - ) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py deleted file mode 100644 index dcad0f302a8e..000000000000 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ /dev/null @@ -1,1376 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE -"""TensorFlow 2.0 MobileViT model.""" - -from __future__ import annotations - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFImageClassifierOutputWithNoAttention, - TFSemanticSegmenterOutputWithNoAttention, -) -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import logging -from .configuration_mobilevit import MobileViTConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "MobileViTConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "apple/mobilevit-small" -_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int: - """ - Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the - original TensorFlow repo. It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - """ - if min_value is None: - min_value = divisor - new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_value < 0.9 * value: - new_value += divisor - return int(new_value) - - -class TFMobileViTConvLayer(keras.layers.Layer): - def __init__( - self, - config: MobileViTConfig, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - groups: int = 1, - bias: bool = False, - dilation: int = 1, - use_normalization: bool = True, - use_activation: bool | str = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - logger.warning( - f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish " - "to train/fine-tune this model, you need a GPU or a TPU" - ) - - padding = int((kernel_size - 1) / 2) * dilation - self.padding = keras.layers.ZeroPadding2D(padding) - - if out_channels % groups != 0: - raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") - - self.convolution = keras.layers.Conv2D( - filters=out_channels, - kernel_size=kernel_size, - strides=stride, - padding="VALID", - dilation_rate=dilation, - groups=groups, - use_bias=bias, - name="convolution", - ) - - if use_normalization: - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") - else: - self.normalization = None - - if use_activation: - if isinstance(use_activation, str): - self.activation = get_tf_activation(use_activation) - elif isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - else: - self.activation = None - self.in_channels = in_channels - self.out_channels = out_channels - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - padded_features = self.padding(features) - features = self.convolution(padded_features) - if self.normalization is not None: - features = self.normalization(features, training=training) - if self.activation is not None: - features = self.activation(features) - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution", None) is not None: - with tf.name_scope(self.convolution.name): - self.convolution.build([None, None, None, self.in_channels]) - if getattr(self, "normalization", None) is not None: - if hasattr(self.normalization, "name"): - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.out_channels]) - - -class TFMobileViTInvertedResidual(keras.layers.Layer): - """ - Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381 - """ - - def __init__( - self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs - ) -> None: - super().__init__(**kwargs) - expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) - - if stride not in [1, 2]: - raise ValueError(f"Invalid stride {stride}.") - - self.use_residual = (stride == 1) and (in_channels == out_channels) - - self.expand_1x1 = TFMobileViTConvLayer( - config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1" - ) - - self.conv_3x3 = TFMobileViTConvLayer( - config, - in_channels=expanded_channels, - out_channels=expanded_channels, - kernel_size=3, - stride=stride, - groups=expanded_channels, - dilation=dilation, - name="conv_3x3", - ) - - self.reduce_1x1 = TFMobileViTConvLayer( - config, - in_channels=expanded_channels, - out_channels=out_channels, - kernel_size=1, - use_activation=False, - name="reduce_1x1", - ) - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - residual = features - - features = self.expand_1x1(features, training=training) - features = self.conv_3x3(features, training=training) - features = self.reduce_1x1(features, training=training) - - return residual + features if self.use_residual else features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "expand_1x1", None) is not None: - with tf.name_scope(self.expand_1x1.name): - self.expand_1x1.build(None) - if getattr(self, "conv_3x3", None) is not None: - with tf.name_scope(self.conv_3x3.name): - self.conv_3x3.build(None) - if getattr(self, "reduce_1x1", None) is not None: - with tf.name_scope(self.reduce_1x1.name): - self.reduce_1x1.build(None) - - -class TFMobileViTMobileNetLayer(keras.layers.Layer): - def __init__( - self, - config: MobileViTConfig, - in_channels: int, - out_channels: int, - stride: int = 1, - num_stages: int = 1, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.layers = [] - for i in range(num_stages): - layer = TFMobileViTInvertedResidual( - config, - in_channels=in_channels, - out_channels=out_channels, - stride=stride if i == 0 else 1, - name=f"layer.{i}", - ) - self.layers.append(layer) - in_channels = out_channels - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - for layer_module in self.layers: - features = layer_module(features, training=training) - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer_module in self.layers: - with tf.name_scope(layer_module.name): - layer_module.build(None) - - -class TFMobileViTSelfAttention(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: - super().__init__(**kwargs) - - if hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size {hidden_size} is not a multiple of the number of attention " - f"heads {config.num_attention_heads}." - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - scale = tf.cast(self.attention_head_size, dtype=tf.float32) - self.scale = tf.math.sqrt(scale) - - self.query = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query") - self.key = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key") - self.value = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value") - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.hidden_size = hidden_size - - def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - batch_size = tf.shape(x)[0] - x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - batch_size = tf.shape(hidden_states)[0] - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - attention_scores = attention_scores / self.scale - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - context_layer = tf.matmul(attention_probs, value_layer) - - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size)) - return context_layer - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.hidden_size]) - - -class TFMobileViTSelfOutput(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(hidden_size, name="dense") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.hidden_size = hidden_size - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.hidden_size]) - - -class TFMobileViTAttention(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: - super().__init__(**kwargs) - self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention") - self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - self_outputs = self.attention(hidden_states, training=training) - attention_output = self.dense_output(self_outputs, training=training) - return attention_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFMobileViTIntermediate(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(intermediate_size, name="dense") - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.hidden_size = hidden_size - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.hidden_size]) - - -class TFMobileViTOutput(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(hidden_size, name="dense") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.intermediate_size = intermediate_size - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = hidden_states + input_tensor - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.intermediate_size]) - - -class TFMobileViTTransformerLayer(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: - super().__init__(**kwargs) - self.attention = TFMobileViTAttention(config, hidden_size, name="attention") - self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate") - self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output") - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - self.hidden_size = hidden_size - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - attention_output = self.attention(self.layernorm_before(hidden_states), training=training) - hidden_states = attention_output + hidden_states - - layer_output = self.layernorm_after(hidden_states) - layer_output = self.intermediate(layer_output) - layer_output = self.mobilevit_output(layer_output, hidden_states, training=training) - return layer_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "mobilevit_output", None) is not None: - with tf.name_scope(self.mobilevit_output.name): - self.mobilevit_output.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.hidden_size]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.hidden_size]) - - -class TFMobileViTTransformer(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None: - super().__init__(**kwargs) - - self.layers = [] - for i in range(num_stages): - transformer_layer = TFMobileViTTransformerLayer( - config, - hidden_size=hidden_size, - intermediate_size=int(hidden_size * config.mlp_ratio), - name=f"layer.{i}", - ) - self.layers.append(transformer_layer) - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - for layer_module in self.layers: - hidden_states = layer_module(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer_module in self.layers: - with tf.name_scope(layer_module.name): - layer_module.build(None) - - -class TFMobileViTLayer(keras.layers.Layer): - """ - MobileViT block: https://huggingface.co/papers/2110.02178 - """ - - def __init__( - self, - config: MobileViTConfig, - in_channels: int, - out_channels: int, - stride: int, - hidden_size: int, - num_stages: int, - dilation: int = 1, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.patch_width = config.patch_size - self.patch_height = config.patch_size - - if stride == 2: - self.downsampling_layer = TFMobileViTInvertedResidual( - config, - in_channels=in_channels, - out_channels=out_channels, - stride=stride if dilation == 1 else 1, - dilation=dilation // 2 if dilation > 1 else 1, - name="downsampling_layer", - ) - in_channels = out_channels - else: - self.downsampling_layer = None - - self.conv_kxk = TFMobileViTConvLayer( - config, - in_channels=in_channels, - out_channels=in_channels, - kernel_size=config.conv_kernel_size, - name="conv_kxk", - ) - - self.conv_1x1 = TFMobileViTConvLayer( - config, - in_channels=in_channels, - out_channels=hidden_size, - kernel_size=1, - use_normalization=False, - use_activation=False, - name="conv_1x1", - ) - - self.transformer = TFMobileViTTransformer( - config, hidden_size=hidden_size, num_stages=num_stages, name="transformer" - ) - - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - - self.conv_projection = TFMobileViTConvLayer( - config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection" - ) - - self.fusion = TFMobileViTConvLayer( - config, - in_channels=2 * in_channels, - out_channels=in_channels, - kernel_size=config.conv_kernel_size, - name="fusion", - ) - self.hidden_size = hidden_size - - def unfolding(self, features: tf.Tensor) -> tuple[tf.Tensor, dict]: - patch_width, patch_height = self.patch_width, self.patch_height - patch_area = tf.cast(patch_width * patch_height, "int32") - - batch_size = tf.shape(features)[0] - orig_height = tf.shape(features)[1] - orig_width = tf.shape(features)[2] - channels = tf.shape(features)[3] - - new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") - new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") - - interpolate = new_width != orig_width or new_height != orig_height - if interpolate: - # Note: Padding can be done, but then it needs to be handled in attention function. - features = tf.image.resize(features, size=(new_height, new_width), method="bilinear") - - # number of patches along width and height - num_patch_width = new_width // patch_width - num_patch_height = new_height // patch_height - num_patches = num_patch_height * num_patch_width - - # convert from shape (batch_size, orig_height, orig_width, channels) - # to the shape (batch_size * patch_area, num_patches, channels) - features = tf.transpose(features, [0, 3, 1, 2]) - patches = tf.reshape( - features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width) - ) - patches = tf.transpose(patches, [0, 2, 1, 3]) - patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area)) - patches = tf.transpose(patches, [0, 3, 2, 1]) - patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels)) - - info_dict = { - "orig_size": (orig_height, orig_width), - "batch_size": batch_size, - "channels": channels, - "interpolate": interpolate, - "num_patches": num_patches, - "num_patches_width": num_patch_width, - "num_patches_height": num_patch_height, - } - return patches, info_dict - - def folding(self, patches: tf.Tensor, info_dict: dict) -> tf.Tensor: - patch_width, patch_height = self.patch_width, self.patch_height - patch_area = int(patch_width * patch_height) - - batch_size = info_dict["batch_size"] - channels = info_dict["channels"] - num_patches = info_dict["num_patches"] - num_patch_height = info_dict["num_patches_height"] - num_patch_width = info_dict["num_patches_width"] - - # convert from shape (batch_size * patch_area, num_patches, channels) - # back to shape (batch_size, channels, orig_height, orig_width) - features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1)) - features = tf.transpose(features, perm=(0, 3, 2, 1)) - features = tf.reshape( - features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width) - ) - features = tf.transpose(features, perm=(0, 2, 1, 3)) - features = tf.reshape( - features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width) - ) - features = tf.transpose(features, perm=(0, 2, 3, 1)) - - if info_dict["interpolate"]: - features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear") - - return features - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - # reduce spatial dimensions if needed - if self.downsampling_layer: - features = self.downsampling_layer(features, training=training) - - residual = features - - # local representation - features = self.conv_kxk(features, training=training) - features = self.conv_1x1(features, training=training) - - # convert feature map to patches - patches, info_dict = self.unfolding(features) - - # learn global representations - patches = self.transformer(patches, training=training) - patches = self.layernorm(patches) - - # convert patches back to feature maps - features = self.folding(patches, info_dict) - - features = self.conv_projection(features, training=training) - features = self.fusion(tf.concat([residual, features], axis=-1), training=training) - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv_kxk", None) is not None: - with tf.name_scope(self.conv_kxk.name): - self.conv_kxk.build(None) - if getattr(self, "conv_1x1", None) is not None: - with tf.name_scope(self.conv_1x1.name): - self.conv_1x1.build(None) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.hidden_size]) - if getattr(self, "conv_projection", None) is not None: - with tf.name_scope(self.conv_projection.name): - self.conv_projection.build(None) - if getattr(self, "fusion", None) is not None: - with tf.name_scope(self.fusion.name): - self.fusion.build(None) - if getattr(self, "downsampling_layer", None) is not None: - with tf.name_scope(self.downsampling_layer.name): - self.downsampling_layer.build(None) - - -class TFMobileViTEncoder(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - - self.layers = [] - - # segmentation architectures like DeepLab and PSPNet modify the strides - # of the classification backbones - dilate_layer_4 = dilate_layer_5 = False - if config.output_stride == 8: - dilate_layer_4 = True - dilate_layer_5 = True - elif config.output_stride == 16: - dilate_layer_5 = True - - dilation = 1 - - layer_1 = TFMobileViTMobileNetLayer( - config, - in_channels=config.neck_hidden_sizes[0], - out_channels=config.neck_hidden_sizes[1], - stride=1, - num_stages=1, - name="layer.0", - ) - self.layers.append(layer_1) - - layer_2 = TFMobileViTMobileNetLayer( - config, - in_channels=config.neck_hidden_sizes[1], - out_channels=config.neck_hidden_sizes[2], - stride=2, - num_stages=3, - name="layer.1", - ) - self.layers.append(layer_2) - - layer_3 = TFMobileViTLayer( - config, - in_channels=config.neck_hidden_sizes[2], - out_channels=config.neck_hidden_sizes[3], - stride=2, - hidden_size=config.hidden_sizes[0], - num_stages=2, - name="layer.2", - ) - self.layers.append(layer_3) - - if dilate_layer_4: - dilation *= 2 - - layer_4 = TFMobileViTLayer( - config, - in_channels=config.neck_hidden_sizes[3], - out_channels=config.neck_hidden_sizes[4], - stride=2, - hidden_size=config.hidden_sizes[1], - num_stages=4, - dilation=dilation, - name="layer.3", - ) - self.layers.append(layer_4) - - if dilate_layer_5: - dilation *= 2 - - layer_5 = TFMobileViTLayer( - config, - in_channels=config.neck_hidden_sizes[4], - out_channels=config.neck_hidden_sizes[5], - stride=2, - hidden_size=config.hidden_sizes[2], - num_stages=3, - dilation=dilation, - name="layer.4", - ) - self.layers.append(layer_5) - - def call( - self, - hidden_states: tf.Tensor, - output_hidden_states: bool = False, - return_dict: bool = True, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - all_hidden_states = () if output_hidden_states else None - - for i, layer_module in enumerate(self.layers): - hidden_states = layer_module(hidden_states, training=training) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - - return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer_module in self.layers: - with tf.name_scope(layer_module.name): - layer_module.build(None) - - -@keras_serializable -class TFMobileViTMainLayer(keras.layers.Layer): - config_class = MobileViTConfig - - def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs): - super().__init__(**kwargs) - self.config = config - self.expand_output = expand_output - - self.conv_stem = TFMobileViTConvLayer( - config, - in_channels=config.num_channels, - out_channels=config.neck_hidden_sizes[0], - kernel_size=3, - stride=2, - name="conv_stem", - ) - - self.encoder = TFMobileViTEncoder(config, name="encoder") - - if self.expand_output: - self.conv_1x1_exp = TFMobileViTConvLayer( - config, - in_channels=config.neck_hidden_sizes[5], - out_channels=config.neck_hidden_sizes[6], - kernel_size=1, - name="conv_1x1_exp", - ) - - self.pooler = keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler") - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - embedding_output = self.conv_stem(pixel_values, training=training) - - encoder_outputs = self.encoder( - embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - if self.expand_output: - last_hidden_state = self.conv_1x1_exp(encoder_outputs[0]) - - # Change to NCHW output format to have uniformity in the modules - last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) - - # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) - pooled_output = self.pooler(last_hidden_state) - else: - last_hidden_state = encoder_outputs[0] - # Change to NCHW output format to have uniformity in the modules - last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) - pooled_output = None - - if not return_dict: - output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) - - # Change to NCHW output format to have uniformity in the modules - if not self.expand_output: - remaining_encoder_outputs = encoder_outputs[1:] - remaining_encoder_outputs = tuple( - tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0] - ) - remaining_encoder_outputs = (remaining_encoder_outputs,) - return output + remaining_encoder_outputs - else: - return output + encoder_outputs[1:] - - # Change the other hidden state outputs to NCHW as well - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) - - return TFBaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv_stem", None) is not None: - with tf.name_scope(self.conv_stem.name): - self.conv_stem.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build([None, None, None, None]) - if getattr(self, "conv_1x1_exp", None) is not None: - with tf.name_scope(self.conv_1x1_exp.name): - self.conv_1x1_exp.build(None) - - -class TFMobileViTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = MobileViTConfig - base_model_prefix = "mobilevit" - main_input_name = "pixel_values" - - -MOBILEVIT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MOBILEVIT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`MobileViTImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. -""" - - -@add_start_docstrings( - "The bare MobileViT model outputting raw hidden-states without any specific head on top.", - MOBILEVIT_START_DOCSTRING, -) -class TFMobileViTModel(TFMobileViTPreTrainedModel): - def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - self.expand_output = expand_output - - self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFBaseModelOutputWithPooling: - output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilevit", None) is not None: - with tf.name_scope(self.mobilevit.name): - self.mobilevit.build(None) - - -@add_start_docstrings( - """ - MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - MOBILEVIT_START_DOCSTRING, -) -class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None: - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit") - - # Classifier head - self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) - self.classifier = ( - keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFImageClassifierOutputWithNoAttention, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - labels: tf.Tensor | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFImageClassifierOutputWithNoAttention: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.mobilevit( - pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(self.dropout(pooled_output, training=training)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilevit", None) is not None: - with tf.name_scope(self.mobilevit.name): - self.mobilevit.build(None) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]]) - - -class TFMobileViTASPPPooling(keras.layers.Layer): - def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None: - super().__init__(**kwargs) - - self.global_pool = keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool") - - self.conv_1x1 = TFMobileViTConvLayer( - config, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - use_normalization=True, - use_activation="relu", - name="conv_1x1", - ) - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - spatial_size = shape_list(features)[1:-1] - features = self.global_pool(features) - features = self.conv_1x1(features, training=training) - features = tf.image.resize(features, size=spatial_size, method="bilinear") - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "global_pool", None) is not None: - with tf.name_scope(self.global_pool.name): - self.global_pool.build([None, None, None, None]) - if getattr(self, "conv_1x1", None) is not None: - with tf.name_scope(self.conv_1x1.name): - self.conv_1x1.build(None) - - -class TFMobileViTASPP(keras.layers.Layer): - """ - ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587 - """ - - def __init__(self, config: MobileViTConfig, **kwargs) -> None: - super().__init__(**kwargs) - - in_channels = config.neck_hidden_sizes[-2] - out_channels = config.aspp_out_channels - - if len(config.atrous_rates) != 3: - raise ValueError("Expected 3 values for atrous_rates") - - self.convs = [] - - in_projection = TFMobileViTConvLayer( - config, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - use_activation="relu", - name="convs.0", - ) - self.convs.append(in_projection) - - self.convs.extend( - [ - TFMobileViTConvLayer( - config, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - dilation=rate, - use_activation="relu", - name=f"convs.{i + 1}", - ) - for i, rate in enumerate(config.atrous_rates) - ] - ) - - pool_layer = TFMobileViTASPPPooling( - config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}" - ) - self.convs.append(pool_layer) - - self.project = TFMobileViTConvLayer( - config, - in_channels=5 * out_channels, - out_channels=out_channels, - kernel_size=1, - use_activation="relu", - name="project", - ) - - self.dropout = keras.layers.Dropout(config.aspp_dropout_prob) - - def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: - # since the hidden states were transposed to have `(batch_size, channels, height, width)` - # layout we transpose them back to have `(batch_size, height, width, channels)` layout. - features = tf.transpose(features, perm=[0, 2, 3, 1]) - pyramid = [] - for conv in self.convs: - pyramid.append(conv(features, training=training)) - pyramid = tf.concat(pyramid, axis=-1) - - pooled_features = self.project(pyramid, training=training) - pooled_features = self.dropout(pooled_features, training=training) - return pooled_features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "project", None) is not None: - with tf.name_scope(self.project.name): - self.project.build(None) - if getattr(self, "convs", None) is not None: - for conv in self.convs: - with tf.name_scope(conv.name): - conv.build(None) - - -class TFMobileViTDeepLabV3(keras.layers.Layer): - """ - DeepLabv3 architecture: https://huggingface.co/papers/1706.05587 - """ - - def __init__(self, config: MobileViTConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.aspp = TFMobileViTASPP(config, name="aspp") - - self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) - - self.classifier = TFMobileViTConvLayer( - config, - in_channels=config.aspp_out_channels, - out_channels=config.num_labels, - kernel_size=1, - use_normalization=False, - use_activation=False, - bias=True, - name="classifier", - ) - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - features = self.aspp(hidden_states[-1], training=training) - features = self.dropout(features, training=training) - features = self.classifier(features, training=training) - return features - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "aspp", None) is not None: - with tf.name_scope(self.aspp.name): - self.aspp.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC. - """, - MOBILEVIT_START_DOCSTRING, -) -class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel): - def __init__(self, config: MobileViTConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.num_labels = config.num_labels - self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit") - self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head") - - def hf_compute_loss(self, logits, labels): - # upsample logits to the images' original size - # `labels` is of shape (batch_size, height, width) - label_interp_shape = shape_list(labels)[1:] - - upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") - # compute weighted loss - loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") - - def masked_loss(real, pred): - unmasked_loss = loss_fct(real, pred) - mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype) - masked_loss = unmasked_loss * mask - # Reduction strategy in the similar spirit with - # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210 - reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask) - return tf.reshape(reduced_masked_loss, (1,)) - - return masked_loss(labels, upsampled_logits) - - @unpack_inputs - @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFSemanticSegmenterOutputWithNoAttention: - r""" - labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): - Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFMobileViTForSemanticSegmentation - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small") - >>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small") - - >>> inputs = image_processor(images=image, return_tensors="tf") - - >>> outputs = model(**inputs) - - >>> # logits are of shape (batch_size, num_labels, height, width) - >>> logits = outputs.logits - ```""" - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None and not self.config.num_labels > 1: - raise ValueError("The number of labels should be greater than one") - - outputs = self.mobilevit( - pixel_values, - output_hidden_states=True, # we need the intermediate hidden states - return_dict=return_dict, - training=training, - ) - - encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] - - logits = self.segmentation_head(encoder_hidden_states, training=training) - - loss = None - if labels is not None: - loss = self.hf_compute_loss(logits=logits, labels=labels) - - # make logits of shape (batch_size, num_labels, height, width) to - # keep them consistent across APIs - logits = tf.transpose(logits, perm=[0, 3, 1, 2]) - - if not return_dict: - if output_hidden_states: - output = (logits,) + outputs[1:] - else: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSemanticSegmenterOutputWithNoAttention( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states if output_hidden_states else None, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mobilevit", None) is not None: - with tf.name_scope(self.mobilevit.name): - self.mobilevit.build(None) - if getattr(self, "segmentation_head", None) is not None: - with tf.name_scope(self.segmentation_head.name): - self.segmentation_head.build(None) - - -__all__ = [ - "TFMobileViTForImageClassification", - "TFMobileViTForSemanticSegmentation", - "TFMobileViTModel", - "TFMobileViTPreTrainedModel", -] diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py deleted file mode 100644 index 1afea867df35..000000000000 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ /dev/null @@ -1,1353 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 MPNet model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_mpnet import MPNetConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "microsoft/mpnet-base" -_CONFIG_FOR_DOC = "MPNetConfig" - - -class TFMPNetPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = MPNetConfig - base_model_prefix = "mpnet" - - -class TFMPNetEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position embeddings.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = tf.math.cumsum(mask, axis=1) * mask - - return incremental_indices + self.padding_idx - - def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - final_embeddings = inputs_embeds + position_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->MPNet -class TFMPNetPooler(keras.layers.Layer): - def __init__(self, config: MPNetConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFMPNetSelfAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - - self.num_attention_heads = config.num_attention_heads - assert config.hidden_size % config.num_attention_heads == 0 - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.q = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q" - ) - self.k = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k" - ) - self.v = keras.layers.Dense( - self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v" - ) - self.o = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o" - ) - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, x, batch_size): - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - return tf.transpose(x, perm=[0, 2, 1, 3]) - - def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False): - batch_size = shape_list(hidden_states)[0] - - q = self.q(hidden_states) - k = self.k(hidden_states) - v = self.v(hidden_states) - - q = self.transpose_for_scores(q, batch_size) - k = self.transpose_for_scores(k, batch_size) - v = self.transpose_for_scores(v, batch_size) - - attention_scores = tf.matmul(q, k, transpose_b=True) - dk = tf.cast(shape_list(k)[-1], attention_scores.dtype) - attention_scores = attention_scores / tf.math.sqrt(dk) - - # Apply relative position embedding (precomputed in MPNetEncoder) if provided. - if position_bias is not None: - attention_scores += position_bias - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - attention_probs = stable_softmax(attention_scores, axis=-1) - - attention_probs = self.dropout(attention_probs, training=training) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - c = tf.matmul(attention_probs, v) - c = tf.transpose(c, perm=[0, 2, 1, 3]) - c = tf.reshape(c, (batch_size, -1, self.all_head_size)) - o = self.o(c) - - outputs = (o, attention_probs) if output_attentions else (o,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q", None) is not None: - with tf.name_scope(self.q.name): - self.q.build([None, None, self.config.hidden_size]) - if getattr(self, "k", None) is not None: - with tf.name_scope(self.k.name): - self.k.build([None, None, self.config.hidden_size]) - if getattr(self, "v", None) is not None: - with tf.name_scope(self.v.name): - self.v.build([None, None, self.config.hidden_size]) - if getattr(self, "o", None) is not None: - with tf.name_scope(self.o.name): - self.o.build([None, None, self.config.hidden_size]) - - -class TFMPNetAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.attn = TFMPNetSelfAttention(config, name="attn") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.config = config - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, input_tensor, attention_mask, head_mask, output_attentions, position_bias=None, training=False): - self_outputs = self.attn( - input_tensor, attention_mask, head_mask, output_attentions, position_bias=position_bias, training=training - ) - attention_output = self.LayerNorm(self.dropout(self_outputs[0]) + input_tensor) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet -class TFMPNetIntermediate(keras.layers.Layer): - def __init__(self, config: MPNetConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet -class TFMPNetOutput(keras.layers.Layer): - def __init__(self, config: MPNetConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFMPNetLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.attention = TFMPNetAttention(config, name="attention") - self.intermediate = TFMPNetIntermediate(config, name="intermediate") - self.out = TFMPNetOutput(config, name="output") - - def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False): - self_attention_outputs = self.attention( - hidden_states, attention_mask, head_mask, output_attentions, position_bias=position_bias, training=training - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - intermediate_output = self.intermediate(attention_output) - layer_output = self.out(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + outputs # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "out", None) is not None: - with tf.name_scope(self.out.name): - self.out.build(None) - - -class TFMPNetEncoder(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.n_heads = config.num_attention_heads - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.initializer_range = config.initializer_range - - self.layer = [TFMPNetLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - self.relative_attention_num_buckets = config.relative_attention_num_buckets - - def build(self, input_shape=None): - if self.built: - return - self.built = True - with tf.name_scope("relative_attention_bias"): - self.relative_attention_bias = self.add_weight( - name="embeddings", - shape=[self.relative_attention_num_buckets, self.n_heads], - initializer=get_initializer(self.initializer_range), - ) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - def call( - self, - hidden_states, - attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=False, - ): - position_bias = self.compute_position_bias(hidden_states) - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask, - head_mask[i], - output_attentions, - position_bias=position_bias, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - @staticmethod - def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): - ret = 0 - n = -relative_position - - num_buckets //= 2 - ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets - n = tf.math.abs(n) - - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = tf.math.less(n, max_exact) - - val_if_large = max_exact + tf.cast( - tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact), - dtype=relative_position.dtype, - ) - - val_if_large = tf.math.minimum(val_if_large, num_buckets - 1) - ret += tf.where(is_small, n, val_if_large) - return ret - - def compute_position_bias(self, x, position_ids=None): - """Compute binned relative position bias""" - input_shape = shape_list(x) - qlen, klen = input_shape[1], input_shape[1] - - if position_ids is not None: - context_position = position_ids[:, :, None] - memory_position = position_ids[:, None, :] - else: - context_position = tf.range(qlen)[:, None] - memory_position = tf.range(klen)[None, :] - - relative_position = memory_position - context_position # shape (qlen, klen) - - rp_bucket = self._relative_position_bucket( - relative_position, - num_buckets=self.relative_attention_num_buckets, - ) - values = tf.gather(self.relative_attention_bias, rp_bucket) # shape (qlen, klen, num_heads) - values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen) - return values - - -@keras_serializable -class TFMPNetMainLayer(keras.layers.Layer): - config_class = MPNetConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.encoder = TFMPNetEncoder(config, name="encoder") - self.pooler = TFMPNetPooler(config, name="pooler") - # The embeddings must be the last declaration in order to follow the weights order - self.embeddings = TFMPNetEmbeddings(config, name="embeddings") - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(input_shape, 1) - - embedding_output = self.embeddings( - input_ids, - position_ids, - inputs_embeds, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - - encoder_outputs = self.encoder( - embedding_output, - extended_attention_mask, - head_mask, - output_attentions, - output_hidden_states, - return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - - -MPNET_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`MPNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MPNET_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare MPNet Model transformer outputting raw hidden-states without any specific head on top.", - MPNET_START_DOCSTRING, -) -class TFMPNetModel(TFMPNetPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.array | tf.Tensor | None = None, - position_ids: np.array | tf.Tensor | None = None, - head_mask: np.array | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.mpnet( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - - -class TFMPNetLMHead(keras.layers.Layer): - """MPNet head for masked and permuted language modeling""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings("""MPNet Model with a `language modeling` head on top.""", MPNET_START_DOCSTRING) -class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss): - _keys_to_ignore_on_load_missing = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.mpnet( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -class TFMPNetClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MPNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. - """, - MPNET_START_DOCSTRING, -) -class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassificationLoss): - _keys_to_ignore_on_load_missing = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - self.classifier = TFMPNetClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.array | tf.Tensor | None = None, - position_ids: np.array | tf.Tensor | None = None, - head_mask: np.array | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.mpnet( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - MPNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - MPNET_START_DOCSTRING, -) -class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.mpnet( - flat_input_ids, - flat_attention_mask, - flat_position_ids, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MPNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - MPNET_START_DOCSTRING, -) -class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificationLoss): - _keys_to_ignore_on_load_missing = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.mpnet( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - MPNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - MPNET_START_DOCSTRING, -) -class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLoss): - _keys_to_ignore_on_load_missing = [r"pooler"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.mpnet = TFMPNetMainLayer(config, name="mpnet") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.array | tf.Tensor | None = None, - position_ids: np.array | tf.Tensor | None = None, - head_mask: np.array | tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: tf.Tensor | None = None, - end_positions: tf.Tensor | None = None, - training: bool = False, - **kwargs, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.mpnet( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions, "end_position": end_positions} - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "mpnet", None) is not None: - with tf.name_scope(self.mpnet.name): - self.mpnet.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFMPNetEmbeddings", - "TFMPNetForMaskedLM", - "TFMPNetForMultipleChoice", - "TFMPNetForQuestionAnswering", - "TFMPNetForSequenceClassification", - "TFMPNetForTokenClassification", - "TFMPNetMainLayer", - "TFMPNetModel", - "TFMPNetPreTrainedModel", -] diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py deleted file mode 100644 index 13bd83b75034..000000000000 --- a/src/transformers/models/mt5/modeling_flax_mt5.py +++ /dev/null @@ -1,123 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax mT5 model.""" - -import jax.numpy as jnp - -from ...utils import logging -from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model -from .configuration_mt5 import MT5Config - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "T5Config" - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -class FlaxMT5Model(FlaxT5Model): - r""" - This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage - examples. - - Examples: - - ```python - >>> from transformers import FlaxMT5Model, AutoTokenizer - - >>> model = FlaxMT5Model.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> summary = "Weiter Verhandlung in Syrien." - >>> inputs = tokenizer(article, return_tensors="np") - - >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids - - >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids) - >>> hidden_states = outputs.last_hidden_state - ```""" - - model_type = "mt5" - config_class = MT5Config - - -class FlaxMT5EncoderModel(FlaxT5EncoderModel): - r""" - This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation - alongside usage examples. - - Examples: - - ```python - >>> from transformers import FlaxT5EncoderModel, AutoTokenizer - - >>> model = FlaxT5EncoderModel.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> summary = "Weiter Verhandlung in Syrien." - >>> inputs = tokenizer(article, return_tensors="np") - - >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids - - >>> outputs = model(input_ids=inputs["input_ids"]) - >>> hidden_states = outputs.last_hidden_state - ```""" - - model_type = "mt5" - config_class = MT5Config - - -class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration): - r""" - This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate - documentation alongside usage examples. - - Examples: - - ```python - >>> from transformers import FlaxMT5ForConditionalGeneration, AutoTokenizer - - >>> model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> summary = "Weiter Verhandlung in Syrien." - >>> inputs = tokenizer(article, return_tensors="np") - - >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids - - >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids) - >>> logits = outputs.logits - ```""" - - model_type = "mt5" - config_class = MT5Config - - -__all__ = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"] diff --git a/src/transformers/models/mt5/modeling_tf_mt5.py b/src/transformers/models/mt5/modeling_tf_mt5.py deleted file mode 100644 index 6152aea0a5ac..000000000000 --- a/src/transformers/models/mt5/modeling_tf_mt5.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tensorflow mT5 model.""" - -from ...utils import logging -from ..t5.modeling_tf_t5 import TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model -from .configuration_mt5 import MT5Config - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "T5Config" - - -class TFMT5Model(TFT5Model): - r""" - This class overrides [`TFT5Model`]. Please check the superclass for the appropriate documentation alongside usage - examples. - - Examples: - - ```python - >>> from transformers import TFMT5Model, AutoTokenizer - - >>> model = TFMT5Model.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> summary = "Weiter Verhandlung in Syrien." - >>> inputs = tokenizer(article, return_tensors="tf") - >>> labels = tokenizer(text_target=summary, return_tensors="tf") - - >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) - >>> hidden_states = outputs.last_hidden_state - ```""" - - model_type = "mt5" - config_class = MT5Config - - -class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration): - r""" - This class overrides [`TFT5ForConditionalGeneration`]. Please check the superclass for the appropriate - documentation alongside usage examples. - - Examples: - - ```python - >>> from transformers import TFMT5ForConditionalGeneration, AutoTokenizer - - >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> summary = "Weiter Verhandlung in Syrien." - >>> inputs = tokenizer(article, text_target=summary, return_tensors="tf") - - >>> outputs = model(**inputs) - >>> loss = outputs.loss - ```""" - - model_type = "mt5" - config_class = MT5Config - - -class TFMT5EncoderModel(TFT5EncoderModel): - r""" - This class overrides [`TFT5EncoderModel`]. Please check the superclass for the appropriate documentation alongside - usage examples. - - Examples: - - ```python - >>> from transformers import TFMT5EncoderModel, AutoTokenizer - - >>> model = TFMT5EncoderModel.from_pretrained("google/mt5-small") - >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") - >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." - >>> input_ids = tokenizer(article, return_tensors="tf").input_ids - >>> outputs = model(input_ids) - >>> hidden_state = outputs.last_hidden_state - ```""" - - model_type = "mt5" - config_class = MT5Config - - -__all__ = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"] diff --git a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 39653e4b1c77..000000000000 --- a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,60 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The MyT5 authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert MyT5 checkpoint.""" - -import argparse - -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 -from transformers.utils import logging - - -logging.set_verbosity_info() - - -# Copied from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - config = T5Config.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = T5ForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_tf_weights_in_t5(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained MyT5 model. \nThis specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 3d5218c20426..000000000000 --- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,74 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OpenAI GPT checkpoint.""" - -import argparse - -import torch - -from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -logging.set_verbosity_info() - - -def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): - # Construct model - if openai_config_file == "": - config = OpenAIGPTConfig() - else: - config = OpenAIGPTConfig.from_json_file(openai_config_file) - model = OpenAIGPTModel(config) - - # Load weights from numpy - load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME - print(f"Save PyTorch model to {pytorch_weights_dump_path}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {pytorch_config_dump_path}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--openai_checkpoint_folder_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint path.", - ) - parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--openai_config_file", - default="", - type=str, - help=( - "An optional config json file corresponding to the pre-trained OpenAI model. \n" - "This specifies the model architecture." - ), - ) - args = parser.parse_args() - convert_openai_checkpoint_to_pytorch( - args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path - ) diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py deleted file mode 100644 index 0235159633b4..000000000000 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ /dev/null @@ -1,936 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 OpenAI GPT model.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFConv1D, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFSharedEmbeddings, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_openai import OpenAIGPTConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt" -_CONFIG_FOR_DOC = "OpenAIGPTConfig" - - -class TFAttention(keras.layers.Layer): - def __init__(self, nx, config, scale=False, **kwargs): - super().__init__(**kwargs) - - n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implementation] - assert n_state % config.n_head == 0, ( - f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}" - ) - self.n_head = config.n_head - self.split_size = n_state - self.scale = scale - self.output_attentions = config.output_attentions - - self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") - self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") - self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) - self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) - self.n_state = n_state - self.pruned_heads = set() - - def prune_heads(self, heads): - pass - - @staticmethod - def causal_attention_mask(nd, ns): - """ - 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), - -1, ns-nd), but doesn't produce garbage on TPUs. - """ - i = tf.range(nd)[:, None] - j = tf.range(ns) - m = i >= j - ns + nd - return m - - def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False): - # q, k, v have shape [batch, heads, sequence, features] - w = tf.matmul(q, k, transpose_b=True) - if self.scale: - dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores - w = w / tf.math.sqrt(dk) - - # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. - _, _, nd, ns = shape_list(w) - b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype) - b = tf.reshape(b, [1, 1, nd, ns]) - w = w * b - 1e4 * (1 - b) - - if attention_mask is not None: - # Apply the attention mask - attention_mask = tf.cast(attention_mask, dtype=w.dtype) - w = w + attention_mask - - w = stable_softmax(w, axis=-1) - w = self.attn_dropout(w, training=training) - - # Mask heads if we want to - if head_mask is not None: - w = w * head_mask - - outputs = [tf.matmul(w, v)] - if output_attentions: - outputs.append(w) - return outputs - - def merge_heads(self, x): - x = tf.transpose(x, [0, 2, 1, 3]) - x_shape = shape_list(x) - new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] - return tf.reshape(x, new_x_shape) - - def split_heads(self, x): - x_shape = shape_list(x) - new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - - def call(self, x, attention_mask, head_mask, output_attentions, training=False): - x = self.c_attn(x) - query, key, value = tf.split(x, 3, axis=2) - query = self.split_heads(query) - key = self.split_heads(key) - value = self.split_heads(value) - - attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training) - a = attn_outputs[0] - - a = self.merge_heads(a) - a = self.c_proj(a) - a = self.resid_dropout(a, training=training) - - outputs = [a] + attn_outputs[1:] - return outputs # a, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "c_attn", None) is not None: - with tf.name_scope(self.c_attn.name): - self.c_attn.build([None, None, self.n_state * 3]) - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.n_state]) - - -class TFMLP(keras.layers.Layer): - def __init__(self, n_state, config, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") - self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") - self.act = get_tf_activation("gelu") - self.dropout = keras.layers.Dropout(config.resid_pdrop) - self.nx = nx - self.n_state = n_state - - def call(self, x, training=False): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - h2 = self.dropout(h2, training=training) - return h2 - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "c_fc", None) is not None: - with tf.name_scope(self.c_fc.name): - self.c_fc.build([None, None, self.n_state]) - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.nx]) - - -class TFBlock(keras.layers.Layer): - def __init__(self, config, scale=False, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - self.attn = TFAttention(nx, config, scale, name="attn") - self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") - self.mlp = TFMLP(4 * nx, config, name="mlp") - self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") - self.nx = nx - - def call(self, x, attention_mask, head_mask, output_attentions, training=False): - output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training) - a = output_attn[0] # output_attn: a, (attentions) - - n = self.ln_1(x + a) - m = self.mlp(n, training=training) - h = self.ln_2(n + m) - - outputs = [h] + output_attn[1:] - return outputs # x, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "ln_1", None) is not None: - with tf.name_scope(self.ln_1.name): - self.ln_1.build([None, None, self.nx]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "ln_2", None) is not None: - with tf.name_scope(self.ln_2.name): - self.ln_2.build([None, None, self.nx]) - - -@keras_serializable -class TFOpenAIGPTMainLayer(keras.layers.Layer): - config_class = OpenAIGPTConfig - - def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.return_dict = config.use_return_dict - self.num_hidden_layers = config.n_layer - self.n_embd = config.n_embd - self.n_positions = config.n_positions - self.initializer_range = config.initializer_range - - self.tokens_embed = TFSharedEmbeddings( - config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed" - ) - self.drop = keras.layers.Dropout(config.embd_pdrop) - self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] - - def build(self, input_shape=None): - with tf.name_scope("positions_embed"): - self.positions_embed = self.add_weight( - name="embeddings", - shape=[self.n_positions, self.n_embd], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "tokens_embed", None) is not None: - with tf.name_scope(self.tokens_embed.name): - self.tokens_embed.build(None) - if getattr(self, "h", None) is not None: - for layer in self.h: - with tf.name_scope(layer.name): - layer.build(None) - - def get_input_embeddings(self): - return self.tokens_embed - - def set_input_embeddings(self, value): - self.tokens_embed.weight = value - self.tokens_embed.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutput: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0) - - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - one_cst = tf.constant(1.0) - attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) - attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) - else: - attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - # head_mask = tf.constant([0] * self.num_hidden_layers) - - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.tokens_embed(input_ids, mode="embedding") - position_embeds = tf.gather(self.positions_embed, position_ids) - if token_type_ids is not None: - token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids") - token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding") - else: - token_type_embeds = 0 - hidden_states = inputs_embeds + position_embeds + token_type_embeds - hidden_states = self.drop(hidden_states, training=training) - - output_shape = input_shape + [shape_list(hidden_states)[-1]] - - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): - if output_hidden_states: - all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - - outputs = block( - hidden_states, - attention_mask, - head_mask[i], - output_attentions, - training=training, - ) - hidden_states = outputs[0] - if output_attentions: - all_attentions = all_attentions + (outputs[1],) - - hidden_states = tf.reshape(hidden_states, output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if output_attentions: - # let the number of heads free (-1) so we can extract attention even after head pruning - attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] - all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = OpenAIGPTConfig - base_model_prefix = "transformer" - - -@dataclass -class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - mc_logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -OPENAI_GPT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -OPENAI_GPT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutput: - outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - # OpenAIGPT does not have past caching features - self.supports_xla_generation = False - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFCausalLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - - logits = self.transformer.tokens_embed(hidden_states, mode="linear") - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def prepare_inputs_for_generation(self, inputs, **kwargs): - return {"input_ids": inputs} - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for - RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the - input embeddings, the classification head takes as input the input of a specified classification token index in the - input sequence). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - config.num_labels = 1 - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - self.multiple_choice_head = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="multiple_choice_head" - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - mc_token_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFOpenAIGPTDoubleHeadsModelOutput: - r""" - mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") - >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") - - >>> # Add a [CLS] to the vocabulary (we should train it also!) - >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"}) - >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size - >>> print(tokenizer.cls_token_id, len(tokenizer)) # The newly token the last token of the vocabulary - - >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] - >>> encoding = tokenizer(choices, return_tensors="tf") - >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()} - >>> inputs["mc_token_ids"] = tf.constant( - ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1] - ... )[ - ... None, : - ... ] # Batch size 1 - >>> outputs = model(inputs) - >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] - ```""" - - if input_ids is not None: - input_shapes = shape_list(input_ids) - else: - input_shapes = shape_list(inputs_embeds)[:-1] - - seq_length = input_shapes[-1] - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - transformer_outputs = self.transformer( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) - if return_dict and output_hidden_states: - # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the - # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) - all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) - else: - all_hidden_states = None - lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) - mc_logits = tf.squeeze(mc_logits, axis=-1) - - if not return_dict: - return (lm_logits, mc_logits) + transformer_outputs[1:] - - return TFOpenAIGPTDoubleHeadsModelOutput( - logits=lm_logits, - mc_logits=mc_logits, - hidden_states=all_hidden_states, - attentions=transformer_outputs.attentions, - ) - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), - "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "multiple_choice_head", None) is not None: - with tf.name_scope(self.multiple_choice_head.name): - self.multiple_choice_head.build(None) - - -@add_start_docstrings( - """ - The OpenAI GPT Model transformer with a sequence classification head on top (linear layer). - - [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal - models (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.score = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="score", - use_bias=False, - ) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple | TFSequenceClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=pooled_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "score", None) is not None: - with tf.name_scope(self.score.name): - self.score.build([None, None, self.config.n_embd]) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -__all__ = [ - "TFOpenAIGPTDoubleHeadsModel", - "TFOpenAIGPTForSequenceClassification", - "TFOpenAIGPTLMHeadModel", - "TFOpenAIGPTMainLayer", - "TFOpenAIGPTModel", - "TFOpenAIGPTPreTrainedModel", -] diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py deleted file mode 100644 index d2f77ecbee26..000000000000 --- a/src/transformers/models/opt/modeling_flax_opt.py +++ /dev/null @@ -1,802 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax OPT model.""" - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, logging -from .configuration_opt import OPTConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/opt-350m" -_CONFIG_FOR_DOC = "OPTConfig" - - -OPT_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`OPTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -OPT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT -class FlaxOPTAttention(nn.Module): - config: OPTConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class FlaxOPTDecoderLayer(nn.Module): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.hidden_size - self.self_attn = FlaxOPTAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.num_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.do_layer_norm_before = self.config.do_layer_norm_before - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - init_cache=init_cache, - deterministic=deterministic, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - hidden_states_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - hidden_states = (residual + hidden_states).reshape(hidden_states_shape) - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class FlaxOPTDecoderLayerCollection(nn.Module): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - self.layerdrop = self.config.layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - outputs = [hidden_states, all_hidden_states, all_self_attns] - return outputs - - -class FlaxOPTLearnedPositionalEmbedding(nn.Embed): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def setup(self): - self.offset = 2 - self.embedding = self.param( - "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype - ) - - def __call__(self, positions): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - - return super().__call__(positions + self.offset) - - -class FlaxOPTDecoder(nn.Module): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - offset: int = 2 - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.hidden_size - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - - self.embed_tokens = nn.Embed( - self.config.vocab_size, - self.config.word_embed_proj_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.embed_positions = FlaxOPTLearnedPositionalEmbedding( - self.config.max_position_embeddings, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - if self.config.word_embed_proj_dim != self.config.hidden_size: - self.project_in = nn.Dense(self.config.hidden_size, use_bias=False) - self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False) - - else: - self.project_in = None - self.project_out = None - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm: - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - else: - self.final_layer_norm = None - - self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - positions = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + positions - - hidden_state, all_hidden_states, attentions = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if self.final_layer_norm is not None: - hidden_state = self.final_layer_norm(hidden_state) - - if self.project_out is not None: - hidden_state = self.project_out(hidden_state) - - if output_hidden_states: - all_hidden_states += (hidden_state,) - - outputs = [hidden_state, all_hidden_states, attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=all_hidden_states, - attentions=attentions, - ) - - -class FlaxOPTPreTrainedModel(FlaxPreTrainedModel): - config_class = OPTConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: OPTConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - return_dict=False, - ) - - random_params = module_init_outputs["params"] - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - dropout_rng: PRNGKey = None, - deterministic: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if position_ids is None: - position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1 - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxOPTAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -class FlaxOPTModule(nn.Module): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype) - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - init_cache=False, - ): - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - init_cache=init_cache, - ) - - if not return_dict: - return decoder_outputs - - return FlaxBaseModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT -class FlaxOPTModel(FlaxOPTPreTrainedModel): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxOPTModule - - -append_call_sample_docstring(FlaxOPTModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) - - -@add_start_docstrings( - "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.", - OPT_START_DOCSTRING, -) -class FlaxOPTForCausalLMModule(nn.Module): - config: OPTConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.model = FlaxOPTModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids, - attention_mask, - position_ids, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=lm_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g for - autoregressive tasks. - """, - OPT_START_DOCSTRING, -) -class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel): - module_class = FlaxOPTForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxOPTForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutput, - _CONFIG_FOR_DOC, -) - - -__all__ = ["FlaxOPTForCausalLM", "FlaxOPTModel", "FlaxOPTPreTrainedModel"] diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py deleted file mode 100644 index f996256063c0..000000000000 --- a/src/transformers/models/opt/modeling_tf_opt.py +++ /dev/null @@ -1,1092 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 OPT model.""" - -from __future__ import annotations - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSharedEmbeddings, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_opt import OPTConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/opt-350m" -_CONFIG_FOR_DOC = "OPTConfig" - -# Base model docstring -_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] - -# Causal LM output -_CAUSAL_LM_EXPECTED_OUTPUT = ( - "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." -) - -LARGE_NEGATIVE = -1e8 - - -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it - mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32)) - mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFOPTLearnedPositionalEmbedding(keras.layers.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): - # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) - - def call(self, attention_mask, past_key_values_length: int = 0): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = tf.cast(attention_mask, tf.int64) - - # create positions depending on attention_mask - positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return super().call(positions + self.offset) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT -class TFOPTAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFOPTDecoderLayer(keras.layers.Layer): - def __init__(self, config: OPTConfig, **kwargs): - super().__init__(**kwargs) - self.do_layer_norm_before = config.do_layer_norm_before - self.embed_dim = config.hidden_size - self.self_attn = TFOPTAttention( - embed_dim=self.embed_dim, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: np.ndarray | tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - training: bool | None = False, - output_attentions: bool | None = False, - use_cache: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - return (hidden_states, self_attn_weights, present_key_value) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -OPT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`OPTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare OPT Model outputting raw hidden-states without any specific head on top.", - OPT_START_DOCSTRING, -) -class TFOPTPreTrainedModel(TFPreTrainedModel): - """ - TFOPT Pretrained Model that inheritates from transformers.TFPreTrainedModel - - Args: - config: OPTConfig - """ - - config_class = OPTConfig - base_model_prefix = "model" - - -OPT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFOPTDecoder(keras.layers.Layer): - config_class = OPTConfig - - def __init__(self, config: OPTConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.layerdrop = config.layerdrop - num_embeddings = config.max_position_embeddings - self.embed_tokens = TFSharedEmbeddings( - config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens" - ) - self.embed_positions = TFOPTLearnedPositionalEmbedding( - num_embeddings, - config.hidden_size, - name="embed_positions", - ) - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if config.do_layer_norm_before and not config._remove_final_layer_norm: - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - else: - self.final_layer_norm = None - - if config.word_embed_proj_dim != config.hidden_size: - self.project_out = keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False) - self.project_in = keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False) - - else: - self.project_in = None - self.project_out = None - - self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.embed_tokens.vocab_size = new_embeddings.shape[0] - self.embed_tokens.weight = new_embeddings - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length): - # create causal mask - # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - _, seq_length = input_shape - tf.debugging.assert_equal( - seq_length + past_key_values_length, - shape_list(attention_mask)[1], - message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}.", - ) - - expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) - if seq_length > 1: - combined_attention_mask = ( - _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask - ) - else: - combined_attention_mask = expanded_attn_mask - - return combined_attention_mask - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPast | tuple[tf.Tensor]: - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of - shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing - `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more - control over how to convert `input_ids` indices into associated vectors than the model's internal - embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is None: - attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool) - else: - tf.debugging.assert_equal( - shape_list(attention_mask)[1], - past_key_values_length + input_shape[1], - message=( - f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be " - f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" - ), - ) - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length) - - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - hidden_states = inputs_embeds + pos_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if self.final_layer_norm is not None: - hidden_states = self.final_layer_norm(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None - ) - - else: - return TFBaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "project_out", None) is not None: - with tf.name_scope(self.project_out.name): - self.project_out.build([None, None, self.config.hidden_size]) - if getattr(self, "project_in", None) is not None: - with tf.name_scope(self.project_in.name): - self.project_in.build([None, None, self.config.word_embed_proj_dim]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFOPTMainLayer(keras.layers.Layer): - config_class = OPTConfig - - def __init__(self, config: OPTConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.decoder = TFOPTDecoder(config, name="decoder") - - def get_input_embeddings(self): - return self.decoder.embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.decoder.set_input_embeddings(new_embeddings) - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFBaseModelOutputWithPast | tuple[tf.Tensor]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.decoder( - input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return outputs - - return TFBaseModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare TF OPT Model outputting raw hidden-states without any specific head on top.", - OPT_START_DOCSTRING, -) -@keras_serializable -class TFOPTModel(TFOPTPreTrainedModel): - config_class = OPTConfig - - def __init__(self, config: OPTConfig, **kwargs): - super().__init__(config, **kwargs) - self.config = config - self.model = TFOPTMainLayer(config, name="model") - - def get_input_embeddings(self): - return self.model.decoder.embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.model.set_input_embeddings(new_embeddings) - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPast, - config_class=_CONFIG_FOR_DOC, - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFBaseModelOutputWithPast | tuple[tf.Tensor]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return outputs - - return TFBaseModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None - attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - - return TFBaseModelOutputWithPast( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - hidden_states=hs, - attentions=attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -@add_start_docstrings( - """ - The OPT Model transformer with a language modeling head on top. - """, - OPT_START_DOCSTRING, -) -@keras_serializable -class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss): - config_class = OPTConfig - - def __init__(self, config: OPTConfig, **kwargs): - super().__init__(config, **kwargs) - self.config = config - self.model = TFOPTMainLayer(config, name="model") - - def get_output_embeddings(self): - return self.model.get_input_embeddings() - - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): - attention_mask = kwargs.get("attention_mask") - - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - inputs = tf.expand_dims(inputs[:, -1], -1) - - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "use_cache": use_cache, - } - - @unpack_inputs - @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithPast, - config_class=_CONFIG_FOR_DOC, - expected_output=_CAUSAL_LM_EXPECTED_OUTPUT, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> TFCausalLMOutputWithPast | tuple[tf.Tensor]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional - tensors are only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - logits = self.model.decoder.embed_tokens(outputs[0], mode="linear") - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None - attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - - return TFCausalLMOutputWithPast( - past_key_values=pkv, - hidden_states=hs, - attentions=attns, - loss=output.loss, - logits=output.logits, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -__all__ = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"] diff --git a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py deleted file mode 100644 index ea766c366f34..000000000000 --- a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py +++ /dev/null @@ -1,406 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OWL-ViT checkpoints from the original repository. URL: -https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit""" - -import argparse -import collections - -import jax -import jax.numpy as jnp -import torch -import torch.nn as nn -from clip.model import CLIP -from flax.training import checkpoints -from huggingface_hub import Repository - -from transformers import ( - CLIPTokenizer, - OwlViTConfig, - OwlViTForObjectDetection, - OwlViTImageProcessor, - OwlViTModel, - OwlViTProcessor, -) - - -CONFIGS = { - "vit_b32": { - "embed_dim": 512, - "image_resolution": 768, - "context_length": 16, - "vocab_size": 49408, - "vision_layers": 12, - "vision_width": 768, - "vision_patch_size": 32, - "transformer_width": 512, - "transformer_heads": 8, - "transformer_layers": 12, - }, - "vit_b16": { - "embed_dim": 512, - "image_resolution": 768, - "context_length": 16, - "vocab_size": 49408, - "vision_layers": 12, - "vision_width": 768, - "vision_patch_size": 16, - "transformer_width": 512, - "transformer_heads": 8, - "transformer_layers": 12, - }, - "vit_l14": { - "embed_dim": 768, - "image_resolution": 840, - "context_length": 16, - "vocab_size": 49408, - "vision_layers": 24, - "vision_width": 1024, - "vision_patch_size": 14, - "transformer_width": 768, - "transformer_heads": 12, - "transformer_layers": 12, - }, -} - - -def flatten_nested_dict(params, parent_key="", sep="/"): - items = [] - - for k, v in params.items(): - new_key = parent_key + sep + k if parent_key else k - - if isinstance(v, collections.MutableMapping): - items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) - - -def to_f32(params): - return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params) - - -def copy_attn_layer(hf_attn_layer, pt_attn_layer): - q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0) - q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0) - - out_proj_weights = pt_attn_layer.out_proj.weight - out_proj_bias = pt_attn_layer.out_proj.bias - - hf_attn_layer.q_proj.weight.data = q_proj - hf_attn_layer.q_proj.bias.data = q_proj_bias - - hf_attn_layer.k_proj.weight.data = k_proj - hf_attn_layer.k_proj.bias.data = k_proj_bias - - hf_attn_layer.v_proj.weight.data = v_proj - hf_attn_layer.v_proj.bias.data = v_proj_bias - - hf_attn_layer.out_proj.weight = out_proj_weights - hf_attn_layer.out_proj.bias = out_proj_bias - - -def copy_mlp(hf_mlp, pt_mlp): - copy_linear(hf_mlp.fc1, pt_mlp.c_fc) - copy_linear(hf_mlp.fc2, pt_mlp.c_proj) - - -def copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - -def copy_layer(hf_layer, pt_layer): - # copy layer norms - copy_linear(hf_layer.layer_norm1, pt_layer.ln_1) - copy_linear(hf_layer.layer_norm2, pt_layer.ln_2) - - # copy MLP - copy_mlp(hf_layer.mlp, pt_layer.mlp) - - # copy attn - copy_attn_layer(hf_layer.self_attn, pt_layer.attn) - - -def copy_layers(hf_layers, pt_layers): - for hf_layer, pt_layer in zip(hf_layers, pt_layers): - copy_layer(hf_layer, pt_layer) - - -def copy_encoder(hf_encoder, pt_model): - # copy embeds - hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight - hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding - - # copy layer norm - copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final) - - # copy hidden layers - copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks) - - -def copy_text_model_and_projection(hf_model, pt_model): - # copy projection - hf_model.text_projection.weight.data = pt_model.text_projection.data.T - - # copy text encoder - copy_encoder(hf_model.text_model, pt_model) - - -def copy_vision_model_and_projection(hf_model, pt_model): - # copy projection - hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T - - # copy layer norms - copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre) - copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) - - # copy embeds - hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data - hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding - hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data - - # copy encoder - copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) - - -def copy_class_merge_token(hf_model, flax_params): - flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"]) - - weight = torch.from_numpy(flax_class_token_params["scale"]) - bias = torch.from_numpy(flax_class_token_params["bias"]) - hf_model.layer_norm.weight = nn.Parameter(weight) - hf_model.layer_norm.bias = nn.Parameter(bias) - - -def copy_class_box_heads(hf_model, flax_params): - pt_params = hf_model.state_dict() - new_params = {} - - # Rename class prediction head flax params to pytorch HF - flax_class_params = flatten_nested_dict(flax_params["class_head"]) - - for flax_key, v in flax_class_params.items(): - torch_key = flax_key.replace("/", ".") - torch_key = torch_key.replace(".kernel", ".weight") - torch_key = torch_key.replace("Dense_0", "dense0") - torch_key = "class_head." + torch_key - - if "weight" in torch_key and v.ndim == 2: - v = v.T - - new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) - - # Rename box prediction box flax params to pytorch HF - flax_box_params = flatten_nested_dict(flax_params["obj_box_head"]) - - for flax_key, v in flax_box_params.items(): - torch_key = flax_key.replace("/", ".") - torch_key = torch_key.replace(".kernel", ".weight") - torch_key = torch_key.replace("_", "").lower() - torch_key = "box_head." + torch_key - - if "weight" in torch_key and v.ndim == 2: - v = v.T - - new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) - - # Copy flax params to PyTorch params - for name, param in new_params.items(): - if name in pt_params: - pt_params[name].copy_(param) - - -def copy_flax_attn_params(hf_backbone, flax_attn_params): - for k, v in flax_attn_params.items(): - if k.startswith("transformer"): - torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers") - else: - torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers") - - torch_key = torch_key.replace("attn", "self_attn") - torch_key = torch_key.replace("key", "k_proj") - torch_key = torch_key.replace("value", "v_proj") - torch_key = torch_key.replace("query", "q_proj") - torch_key = torch_key.replace("out", "out_proj") - - if "bias" in torch_key and v.ndim == 2: - shape = v.shape[0] * v.shape[1] - v = v.reshape(shape) - - if "weight" in torch_key and "out" in torch_key: - shape = (v.shape[0] * v.shape[1], v.shape[2]) - v = v.reshape(shape).T - - if "weight" in torch_key and "out" not in torch_key: - shape = (v.shape[0], v.shape[1] * v.shape[2]) - v = v.reshape(shape).T - - # Copy flax CLIP attn params to HF PyTorch params - v = torch.from_numpy(v) - hf_backbone.state_dict()[torch_key].copy_(v) - - -def _convert_attn_layers(params): - new_params = {} - processed_attn_layers = [] - - for k, v in params.items(): - if "attn." in k: - base = k[: k.rindex("attn.") + 5] - if base in processed_attn_layers: - continue - - processed_attn_layers.append(base) - dim = params[base + "out.weight"].shape[-1] - new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T - new_params[base + "out_proj.bias"] = params[base + "out.bias"] - else: - new_params[k] = v - return new_params - - -def convert_clip_backbone(flax_params, torch_config): - torch_model = CLIP(**torch_config) - torch_model.eval() - torch_clip_params = torch_model.state_dict() - - flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"]) - new_torch_params = {} - - for flax_key, v in flax_clip_params.items(): - torch_key = flax_key.replace("/", ".") - torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel") - - if ( - torch_key.startswith("text.transformer") - or torch_key.startswith("text.text_projection") - or torch_key.startswith("text.ln_final") - or torch_key.startswith("text.positional_embedding") - ): - torch_key = torch_key[5:] - - torch_key = torch_key.replace("text_projection.kernel", "text_projection") - torch_key = torch_key.replace("visual.proj.kernel", "visual.proj") - torch_key = torch_key.replace(".scale", ".weight") - torch_key = torch_key.replace(".kernel", ".weight") - - if "conv" in torch_key or "downsample.0.weight" in torch_key: - v = v.transpose(3, 2, 0, 1) - - elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key: - # Fully connected layers are transposed, embeddings are not - v = v.T - - new_torch_params[torch_key] = v - - attn_params = _convert_attn_layers(new_torch_params) - new_torch_params.update(attn_params) - attn_params = {} - - # Copy flax CLIP backbone params to PyTorch params - for name, param in new_torch_params.items(): - if name in torch_clip_params: - new_param = torch.from_numpy(param) - torch_clip_params[name].copy_(new_param) - else: - attn_params[name] = param - - return torch_clip_params, torch_model, attn_params - - -@torch.no_grad() -def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None): - """ - Copy/paste/tweak model's weights to transformers design. - """ - repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}") - repo.git_pull() - - if config_path is not None: - config = OwlViTConfig.from_pretrained(config_path) - else: - config = OwlViTConfig() - - hf_backbone = OwlViTModel(config).eval() - hf_model = OwlViTForObjectDetection(config).eval() - - copy_text_model_and_projection(hf_backbone, pt_backbone) - copy_vision_model_and_projection(hf_backbone, pt_backbone) - hf_backbone.logit_scale = pt_backbone.logit_scale - copy_flax_attn_params(hf_backbone, attn_params) - - hf_model.owlvit = hf_backbone - copy_class_merge_token(hf_model, flax_params) - copy_class_box_heads(hf_model, flax_params) - - # Save HF model - hf_model.save_pretrained(repo.local_dir) - - # Initialize image processor - image_processor = OwlViTImageProcessor( - size=config.vision_config.image_size, crop_size=config.vision_config.image_size - ) - # Initialize tokenizer - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16) - - # Initialize processor - processor = OwlViTProcessor(image_processor=image_processor, tokenizer=tokenizer) - image_processor.save_pretrained(repo.local_dir) - processor.save_pretrained(repo.local_dir) - - repo.git_add() - repo.git_commit("Upload model and processor") - repo.git_push() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--owlvit_version", - default=None, - type=str, - required=True, - help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].", - ) - parser.add_argument( - "--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint." - ) - parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.") - parser.add_argument( - "--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - - # Initialize PyToch clip model - model_name = args.owlvit_version - if model_name == "clip_b16": - torch_config = CONFIGS["vit_b16"] - elif model_name == "clip_b32": - torch_config = CONFIGS["vit_b32"] - elif model_name == "clip_l14": - torch_config = CONFIGS["vit_l14"] - - # Load from checkpoint and convert params to float-32 - variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"] - flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) - del variables - - # Convert CLIP backbone - pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config) - - convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config) diff --git a/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py b/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py deleted file mode 100644 index 9251c9a92ac6..000000000000 --- a/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Google and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -from pathlib import Path - -import tensorflow as tf -import torch -from tqdm import tqdm - -from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer -from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params - - -PATTERNS = [ - # replace left string with right string to get the relevant state_dict key (identical state dict to bart) - ["memory_attention", "encoder_attn"], - ["attention", "attn"], - ["/", "."], - [".LayerNorm.gamma", "_layer_norm.weight"], - [".LayerNorm.beta", "_layer_norm.bias"], - ["r.layer_", "r.layers."], - ["output_proj", "out_proj"], - ["ffn.dense_1.", "fc2."], - ["ffn.dense.", "fc1."], - ["ffn_layer_norm", "final_layer_norm"], - ["kernel", "weight"], - ["encoder_layer_norm.", "encoder.layer_norm."], - ["decoder_layer_norm.", "decoder.layer_norm."], - ["embeddings.weights", "shared.weight"], -] - - -def rename_state_dict_key(k): - for pegasus_name, hf_name in PATTERNS: - k = k.replace(pegasus_name, hf_name) - return k - - -# See appendix C of paper for all hyperparams - - -def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: - cfg_kwargs = DEFAULTS.copy() - cfg_kwargs.update(cfg_updates) - cfg = PegasusConfig(**cfg_kwargs) - torch_model = PegasusForConditionalGeneration(cfg) - sd = torch_model.model.state_dict() - mapping = {} - for k, v in tf_weights.items(): - new_k = rename_state_dict_key(k) - if new_k not in sd: - raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") - - if "dense" in k or "proj" in new_k: - v = v.T - mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype) - assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}" - # make sure embedding.padding_idx is respected - mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1]) - mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"] - mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"] - empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping} - mapping.update(**empty_biases) - missing, extra = torch_model.model.load_state_dict(mapping, strict=False) - unexpected_missing = [ - k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"] - ] - assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" - assert extra == [], f"no matches found for the following tf keys {extra}" - return torch_model - - -def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> dict: - init_vars = tf.train.list_variables(path) - tf_weights = {} - ignore_name = ["Adafactor", "global_step"] - for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): - skip_key = any(pat in name for pat in ignore_name) - if skip_key: - continue - array = tf.train.load_variable(path, name) - tf_weights[name] = array - return tf_weights - - -def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): - # save tokenizer first - dataset = Path(ckpt_path).parent.name - desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"] - tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length) - assert tok.model_max_length == desired_max_model_length - tok.save_pretrained(save_dir) - - # convert model - tf_weights = get_tf_weights_as_numpy(ckpt_path) - cfg_updates = task_specific_params[f"summarization_{dataset}"] - if dataset == "large": - cfg_updates["task_specific_params"] = task_specific_params - torch_model = convert_pegasus(tf_weights, cfg_updates) - torch_model.save_pretrained(save_dir) - sd = torch_model.state_dict() - sd.pop("model.decoder.embed_positions.weight") - sd.pop("model.encoder.embed_positions.weight") - torch.save(sd, Path(save_dir) / "pytorch_model.bin") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables") - parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.") - args = parser.parse_args() - if args.save_dir is None: - dataset = Path(args.tf_ckpt_path).parent.name - args.save_dir = os.path.join("pegasus", dataset) - convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir) diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py deleted file mode 100644 index ddf0ae492407..000000000000 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ /dev/null @@ -1,1532 +0,0 @@ -# coding=utf-8 -# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax PEGASUS model.""" - -import math -import random -from functools import partial -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - add_start_docstrings_to_model_forward, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, logging, replace_return_docstrings -from .configuration_pegasus import PegasusConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/pegasus-large" -_CONFIG_FOR_DOC = "PegasusConfig" - -PEGASUS_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`PegasusConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -PEGASUS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -PEGASUS_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -PEGASUS_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions -def create_sinusoidal_positions(n_pos, dim): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - sentinel = dim // 2 + dim % 2 - out = np.zeros_like(position_enc) - out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) - out[:, sentinel:] = np.cos(position_enc[:, 1::2]) - - return jnp.array(out) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus -class FlaxPegasusAttention(nn.Module): - config: PegasusConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus -class FlaxPegasusEncoderLayer(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxPegasusAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus -class FlaxPegasusEncoderLayerCollection(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus -class FlaxPegasusDecoderLayer(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxPegasusAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxPegasusAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus -class FlaxPegasusDecoderLayerCollection(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxPegasusEncoder(nn.Module): - config: PegasusConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) - self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) - # explicitly cast the positions here, since self.embed_positions are not registered as parameters - embed_pos = embed_pos.astype(inputs_embeds.dtype) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_state = outputs[0] - last_hidden_state = self.layer_norm(last_hidden_state) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_state,) - - if not return_dict: - outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxPegasusDecoder(nn.Module): - config: PegasusConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) - - self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - positions = jnp.take(self.embed_positions, position_ids, axis=0) - # explicitly cast the positions here, since self.embed_positions are not registered as parameters - positions = positions.astype(inputs_embeds.dtype) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_state = outputs[0] - last_hidden_state = self.layer_norm(last_hidden_state) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_state,) - - if not return_dict: - outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_state, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus -class FlaxPegasusModule(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): - config_class = PegasusConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: PegasusConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration - - >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") - >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration - - >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") - >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxPegasusAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", - PEGASUS_START_DOCSTRING, -) -class FlaxPegasusModel(FlaxPegasusPreTrainedModel): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxPegasusModule - - -append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus -class FlaxPegasusForConditionalGenerationModule(nn.Module): - config: PegasusConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype) - self.lm_head = nn.Dense( - self.model.shared.num_embeddings, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING -) -class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): - module_class = FlaxPegasusForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - deterministic: bool = True, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration - - >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") - >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxPegasusAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias.astype(self.dtype) - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration - - >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') - >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large') - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np') - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs['input_ids']).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") - >>> TXT = "My friends are but they eat too many carbs." - - >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") - >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = jax.lax.top_k(probs) - - >>> tokenizer.decode(predictions).split() - ``` -""" - -overwrite_call_docstring( - FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"] diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py deleted file mode 100644 index d159fc00138d..000000000000 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ /dev/null @@ -1,1573 +0,0 @@ -# coding=utf-8 -# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Pegasus model.""" - -from __future__ import annotations - -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) - -# Public API -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_pegasus import PegasusConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/pegasus-large" -_CONFIG_FOR_DOC = "PegasusConfig" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus -class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer): - """This module produces sinusoidal positional embeddings of any length.""" - - def __init__(self, num_positions: int, embedding_dim: int, **kwargs): - super().__init__(**kwargs) - - if embedding_dim % 2 != 0: - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - - self.embedding_dim = embedding_dim - self.num_positions = num_positions - - def build(self, input_shape: tf.TensorShape): - """ - Build shared token embedding layer Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - - weight = self._init_weight(self.num_positions, self.embedding_dim) - - self.weight = self.add_weight( - name="embeddings", - shape=[self.num_positions, self.embedding_dim], - ) - weight = tf.cast(weight, dtype=self.weight.dtype) - - self.weight.assign(weight) - - super().build(input_shape) - - @staticmethod - def _init_weight(n_pos: int, dim: int): - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. [dim // 2:] - """ - position_enc = np.array( - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] - ) - table = np.zeros_like(position_enc) - # index 0 is all zero - table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) - table[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) - # convert to tensor - table = tf.convert_to_tensor(table) - tf.stop_gradient(table) - return table - - def call( - self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None - ): - """Input is expected to be of size [bsz x seqlen].""" - if position_ids is None: - seq_len = input_shape[1] - position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return tf.gather(self.weight, position_ids) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus -class TFPegasusAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus -class TFPegasusEncoderLayer(keras.layers.Layer): - def __init__(self, config: PegasusConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFPegasusAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - layer_head_mask: tf.Tensor, - training: bool | None = False, - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(encoder_attention_heads,)* - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus -class TFPegasusDecoderLayer(keras.layers.Layer): - def __init__(self, config: PegasusConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFPegasusAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFPegasusAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape *(batch, seq_len, embed_dim)* - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(decoder_attention_heads,)* - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - *(decoder_attention_heads,)* - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFPegasusPreTrainedModel(TFPreTrainedModel): - config_class = PegasusConfig - base_model_prefix = "model" - - -PEGASUS_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`PegasusConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -PEGASUS_GENERATION_EXAMPLE = r""" - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration - - >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum") - >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum") - - >>> ARTICLE_TO_SUMMARIZE = ( - ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " - ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " - ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." - ... ) - >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf") - - >>> # Generate Summary - >>> summary_ids = model.generate(input_ids) - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` -""" - -PEGASUS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation output_attentions (`bool`, - *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` - under returned tensors for more detail. This argument can be used only in eager mode, in graph mode the - value in the config will be used instead. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFPegasusEncoder(keras.layers.Layer): - config_class = PegasusConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFPegasusEncoderLayer`]. - - Args: - config: PegasusConfig - """ - - def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = embed_tokens - self.embed_positions = TFPegasusSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ): - """ - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - # encoder layers - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - ) - - if output_attentions: - all_attentions += (attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFPegasusDecoder(keras.layers.Layer): - config_class = PegasusConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`] - - Args: - config: PegasusConfig - embed_tokens: output embedding - """ - - def __init__(self, config: PegasusConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): - super().__init__(**kwargs) - self.config = config - self.padding_idx = config.pad_token_id - self.embed_tokens = embed_tokens - self.layerdrop = config.decoder_layerdrop - self.embed_positions = TFPegasusSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - name="embed_positions", - ) - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value - in the config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. This argument can be used only in eager mode, in graph mode the value in the config - will be used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - # embed positions - if position_ids is None: - positions = self.embed_positions(input_shape, past_key_values_length) - else: - positions = self.embed_positions(input_shape, position_ids=position_ids) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - hidden_states = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - hidden_states = self.dropout(hidden_states + positions, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_values = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - past_key_value=past_key_value, - ) - - if use_cache: - present_key_values += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFPegasusMainLayer(keras.layers.Layer): - config_class = PegasusConfig - - def __init__(self, config: PegasusConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="model.shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "model.shared" - - self.encoder = TFPegasusEncoder(config, self.shared, name="encoder") - self.decoder = TFPegasusDecoder(config, self.shared, name="decoder") - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - @unpack_inputs - def call( - self, - input_ids: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - decoder_input_ids: tf.Tensor | None = None, - decoder_attention_mask: tf.Tensor | None = None, - decoder_position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - decoder_head_mask: tf.Tensor | None = None, - cross_attn_head_mask: tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[tf.Tensor]] | None = None, - inputs_embeds: tf.Tensor | None = None, - decoder_inputs_embeds: tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ): - if decoder_input_ids is None and decoder_inputs_embeds is None: - use_cache = False - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.", - PEGASUS_START_DOCSTRING, -) -class TFPegasusModel(TFPegasusPreTrainedModel): - def __init__(self, config: PegasusConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFPegasusMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFSeq2SeqModelOutput | tuple[tf.Tensor]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer -class BiasLayer(keras.layers.Layer): - """ - Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, - so all weights have to be registered in a layer. - """ - - def __init__(self, shape, initializer, trainable, name, **kwargs): - super().__init__(name=name, **kwargs) - # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of - # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: - # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 - self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) - - def call(self, x): - return x + self.bias - - -@add_start_docstrings( - "The PEGASUS Model with a language modeling head. Can be used for summarization.", - PEGASUS_START_DOCSTRING, -) -class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model = TFPegasusMainLayer(config, name="model") - self.use_cache = config.use_cache - # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False - ) - - def get_decoder(self): - return self.model.decoder - - def get_encoder(self): - return self.model.encoder - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def get_bias(self): - return {"final_logits_bias": self.bias_layer.bias} - - def set_bias(self, value): - # Replaces the existing layers containing bias for correct (de)serialization. - vocab_size = value["final_logits_bias"].shape[-1] - self.bias_layer = BiasLayer( - name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False - ) - self.bias_layer.bias.assign(value["final_logits_bias"]) - - @unpack_inputs - @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: - """ - labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - """ - - if labels is not None: - labels = tf.where( - labels == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), - labels, - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) - lm_logits = self.bias_layer(lm_logits) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, # index 1 of d outputs - decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs - decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, # index 4 of d outputs - encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs - encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out - encoder_attentions=outputs.encoder_attentions, # 2 of e out - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past_key_values - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "bias_layer", None) is not None: - with tf.name_scope(self.bias_layer.name): - self.bias_layer.build(None) - - -__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py deleted file mode 100644 index 155383772871..000000000000 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ /dev/null @@ -1,1776 +0,0 @@ -# coding=utf-8 -# Copyright 2020, The RAG Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""TFRAG model implementation.""" - -from __future__ import annotations - -import copy -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...configuration_utils import PretrainedConfig -from ...generation import TFLogitsProcessorList -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - keras, - shape_list, - unpack_inputs, -) -from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_rag import RagConfig -from .retrieval_rag import RagRetriever - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "RagConfig" - - -@dataclass -class TFRetrievAugLMMarginOutput(ModelOutput): - """ - Base class for retriever augmented marginalized models outputs. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head. The score is possibly marginalized over all documents for - each vocabulary token. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used - (see `past_key_values` input) to speed up sequential decoding. - doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`): - Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and - `question_encoder_last_hidden_state`. - retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*): - Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute - the `doc_scores`. - retrieved_doc_ids (`tf.Tensor` (int32) of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*): - The indexes of the embedded documents retrieved by the retriever. - context_input_ids (`tf.Tensor`(int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. - context_attention_mask (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. - question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden states at the output of the last layer of the question encoder pooled output of the - model. - question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the question encoder at the output of each layer plus the initial embedding outputs. - question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the question encoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the generator encoder of the model. - generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs. - generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs. - generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - doc_scores: tf.Tensor | None = None - retrieved_doc_embeds: tf.Tensor | None = None - retrieved_doc_ids: tf.Tensor | None = None - context_input_ids: tf.Tensor | None = None - context_attention_mask: tf.Tensor | None = None - question_encoder_last_hidden_state: tf.Tensor | None = None - question_enc_hidden_states: tuple[tf.Tensor, ...] | None = None - question_enc_attentions: tuple[tf.Tensor, ...] | None = None - generator_enc_last_hidden_state: tf.Tensor | None = None - generator_enc_hidden_states: tuple[tf.Tensor, ...] | None = None - generator_enc_attentions: tuple[tf.Tensor, ...] | None = None - generator_dec_hidden_states: tuple[tf.Tensor, ...] | None = None - generator_dec_attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFRetrievAugLMOutput(ModelOutput): - """ - Args: - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head. The score is possibly marginalized over all documents for - each vocabulary token. - past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used - (see `past_key_values` input) to speed up sequential decoding. - doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`): - Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and - `question_encoder_last_hidden_state`. - retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*): - Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute - the `doc_scores`. - retrieved_doc_ids (`tf.Tensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*): - The indexes of the embedded documents retrieved by the retriever. - context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever. - context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. - question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden states at the output of the last layer of the question encoder pooled output of the - model. - question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the question encoder at the output of each layer plus the initial embedding outputs. - question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the question encoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the generator encoder of the model. - generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs. - generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs. - generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted - average in the self-attention heads. - """ - - logits: tf.Tensor | None = None - past_key_values: list[tf.Tensor] | None = None - doc_scores: tf.Tensor | None = None - retrieved_doc_embeds: tf.Tensor | None = None - retrieved_doc_ids: tf.Tensor | None = None - context_input_ids: tf.Tensor | None = None - context_attention_mask: tf.Tensor | None = None - question_encoder_last_hidden_state: tf.Tensor | None = None - question_enc_hidden_states: tuple[tf.Tensor, ...] | None = None - question_enc_attentions: tuple[tf.Tensor, ...] | None = None - generator_enc_last_hidden_state: tf.Tensor | None = None - generator_enc_hidden_states: tuple[tf.Tensor, ...] | None = None - generator_enc_attentions: tuple[tf.Tensor, ...] | None = None - generator_dec_hidden_states: tuple[tf.Tensor, ...] | None = None - generator_dec_attentions: tuple[tf.Tensor, ...] | None = None - - -class TFRagPreTrainedModel(TFPreTrainedModel): - r""" - RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP - Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al. - - RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a - generator, the encoder and generator are trainable while the retriever is just an indexed dataset. - - """ - - config_class = RagConfig - base_model_prefix = "rag" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - @classmethod - def from_pretrained_question_encoder_generator( - cls, - question_encoder_pretrained_model_name_or_path: str | None = None, - generator_pretrained_model_name_or_path: str | None = None, - retriever: RagRetriever = None, - *model_args, - **kwargs, - ) -> TFPreTrainedModel: - r""" - Instantiates an question encoder and a generator from one or two base classes of the library from pretrained - model checkpoints. - - Params: - question_encoder_pretrained_model_name_or_path (`str`, *optional*): - Information necessary to initiate the question encoder. Can be either: - - - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g., - `google-bert/bert-base-uncased`. - - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g., - `dbmdz/bert-base-german-cased`. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case, - `question_encoder_from_pt` should be set to `True`. - - generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): - Information necessary to initiate the generator. Can be either: - - - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g., - `google-t5/t5-small`. - - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g., - `facebook/bart-base`. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case, - `generator_from_pt` should be set to `True`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - retriever ([`RagRetriever`], *optional*): - The retriever to use. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the question_encoder configuration, use the prefix *question_encoder_* for each - configuration parameter. - - To update the generator configuration, use the prefix *generator_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import RagRetriever, TFRagModel - - >>> # initialize a RAG from two pretrained models. - >>> model = TFRagModel.from_pretrained_question_encoder_generator( - ... "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small" - ... ) - >>> # alternatively, initialize from pytorch pretrained models can also be done - >>> model = TFRagModel.from_pretrained_question_encoder_generator( - ... "facebook/dpr-question_encoder-single-nq-base", - ... "facebook/bart-base", - ... generator_from_pt=True, - ... question_encoder_from_pt=True, - ... ) - - >>> # saving model after fine-tuning - >>> model.save_pretrained("./rag") - - >>> # load retriever - >>> retriever = RagRetriever.from_pretrained( - ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True - ... ) - >>> # load fine-tuned model with retriever - >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever) - ```""" - - kwargs_question_encoder = { - argument[len("question_encoder_") :]: value - for argument, value in kwargs.items() - if argument.startswith("question_encoder_") - } - - kwargs_generator = { - argument[len("generator_") :]: value - for argument, value in kwargs.items() - if argument.startswith("generator_") - } - - # remove question_encoder, generator kwargs from kwargs - for key in kwargs_question_encoder: - del kwargs["question_encoder_" + key] - for key in kwargs_generator: - del kwargs["generator_" + key] - - # Load and initialize the question_encoder and generator - # The distinction between question_encoder and generator at the model level is made - # by the value of the flag `is_generator` that we need to set correctly. - question_encoder = kwargs_question_encoder.pop("model", None) - if question_encoder is None: - assert question_encoder_pretrained_model_name_or_path is not None, ( - "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to" - " be defined" - ) - - from ..auto.modeling_tf_auto import TFAutoModel - - if "config" not in kwargs_question_encoder: - from ..auto.configuration_auto import AutoConfig - - question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path) - kwargs_question_encoder["config"] = question_encoder_config - - question_encoder = TFAutoModel.from_pretrained( - question_encoder_pretrained_model_name_or_path, - name="question_encoder", - load_weight_prefix=cls.load_weight_prefix, - *model_args, - **kwargs_question_encoder, - ) - - generator = kwargs_generator.pop("generator", None) - if generator is None: - assert generator_pretrained_model_name_or_path is not None, ( - "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has" - " to be defined" - ) - - from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM - - if "config" not in kwargs_generator: - from ..auto.configuration_auto import AutoConfig - - generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path) - kwargs_generator["config"] = generator_config - - generator = TFAutoModelForSeq2SeqLM.from_pretrained( - generator_pretrained_model_name_or_path, - name="generator", - load_weight_prefix=cls.load_weight_prefix, - **kwargs_generator, - ) - - # instantiate config with corresponding kwargs - config = kwargs.get("config") - if config is None: - config = RagConfig.from_question_encoder_generator_configs( - question_encoder.config, generator.config, **kwargs - ) - - return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever) - - -RAG_START_DOCSTRING = r""" - - RAG is a sequence-to-sequence model which encapsulates two core components: a question encoder and a generator. - During a forward pass, we encode the input with the question encoder and pass it to the retriever to extract - relevant context documents. The documents are then prepended to the input. Such contextualized inputs is passed to - the generator. - - The question encoder can be any *autoencoding* model, preferably [`TFDPRQuestionEncoder`], and the generator can be - any *seq2seq* model, preferably [`TFBartForConditionalGeneration`]. - - The model can be initialized with a [`RagRetriever`] for end-to-end generation or used in combination with the - outputs of a retriever in multiple steps---see examples for more details. The model is compatible any - *autoencoding* model as the `question_encoder` and any *seq2seq* model with language model head as the `generator`. - It has been tested with [`TFDPRQuestionEncoder`] as the `question_encoder` and [`TFBartForConditionalGeneration`] - as the `generator`. - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) - subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to - general usage and behavior. - - The model is in a developing state as it is now fully supports in eager-mode only, and may not be exported in - SavedModel format. - - Args: - config ([`RagConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. - question_encoder ([`TFPreTrainedModel`]): - An encoder model compatible with the faiss index encapsulated by the `retriever`. - generator ([`TFPreTrainedModel`]): - A seq2seq model used as the generator in the RAG architecture. - retriever ([`RagRetriever`]): - A retriever class encapsulating a faiss index queried to obtain context documents for current inputs. -""" - - -RAG_FORWARD_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies - which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to - obtain the indices. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*) - Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`, - *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs * - sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the - generator's encoder. - - Used by the ([`TFRagModel`]) model during decoding. - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Provide for generation tasks. `None` by default, construct as per instructions for the generator model - you're using with your RAG instance. - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - past_key_values (`tuple(tuple(tf.Tensor))`): - Tuple consists of two elements: `encoder_outputs` of the RAG model (see `encoder_outputs`) and - `past_key_values` of the underlying generator. Can be used to speed up decoding. `past_key_values` are used - in the ([`RagTokenForGeneration`]) model during decoding. - doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`): - Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and - `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores` - has to be provided to the forward pass. `doc_scores` can be computed via - `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information. - context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. - - If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the - forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask - (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when - *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the question - encoder `input_ids` by the retriever. - - If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the - forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`]. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - output_retrieved(`bool`, *optional*): - Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and - `context_attention_mask`. See returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`TFRetrievAugLMOutput`] instead of a plain tuple. - n_docs (`int`, *optional*, defaults to `config.n_docs``) - Number of documents to retrieve and/or number of documents for which to generate an answer. -""" - - -@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING) -class TFRagModel(TFRagPreTrainedModel): - load_weight_prefix = "tf_rag_model_1" - - def __init__( - self, - config: PretrainedConfig | None = None, - question_encoder: TFPreTrainedModel | None = None, - generator: TFPreTrainedModel | None = None, - retriever: RagRetriever | None = None, - load_weight_prefix: str | None = None, - **kwargs, - ): - assert config is not None or (question_encoder is not None and generator is not None), ( - "Either a configuration or an question_encoder and a generator has to be provided." - ) - - if config is None: - config = RagConfig.from_question_encoder_generator_configs( - question_encoder.config, generator.config, **kwargs - ) - else: - assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}" - super().__init__(config, **kwargs) - - if question_encoder is None: - from ..auto.modeling_tf_auto import TFAutoModel - - question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder") - - if generator is None: - from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM - - load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix - generator = TFAutoModelForSeq2SeqLM.from_config( - config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator" - ) - - self.retriever = retriever - if self.retriever is not None: - assert isinstance(retriever, RagRetriever), ( - f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`" - ) - self.retriever = retriever - - self.question_encoder = question_encoder - self.generator = generator - - def set_retriever(self, retriever: RagRetriever): - self.retriever = retriever - - @unpack_inputs - @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - doc_scores: np.ndarray | tf.Tensor | None = None, - context_input_ids: np.ndarray | tf.Tensor | None = None, - context_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_retrieved: bool | None = None, - n_docs: int | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFRetrievAugLMOutput: - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, RagRetriever, TFRagModel - >>> import torch - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - from ...utils.deprecation import deprecate_kwarg - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base") - >>> retriever = RagRetriever.from_pretrained( - ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True - ... ) - >>> # initialize with RagRetriever to do everything in one forward call - >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True) - - >>> input_dict = tokenizer.prepare_seq2seq_batch( - ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf" - ... ) - >>> input_ids = input_dict["input_ids"] - >>> outputs = model(input_ids) - ```""" - assert "decoder_cached_states" not in kwargs, ( - "Please use past_key_values to cache intermediate outputs" - ) # from modeling_tf_bart.py - - # aliasing to minimize code changing - n_docs = n_docs if n_docs is not None else self.config.n_docs - - # whether retriever has to be used - has_to_retrieve = ( - self.retriever is not None - and (context_input_ids is None or context_attention_mask is None or doc_scores is None) - and encoder_outputs is None - ) - - # encoder_outputs are pre-computed during RAG-token generation - if encoder_outputs is None: - if has_to_retrieve: - question_enc_outputs = self.question_encoder( - input_ids, attention_mask=attention_mask, return_dict=True, training=training - ) - # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/dpr/modeling_tf_dpr.py#L91 - question_encoder_last_hidden_state = question_enc_outputs[ - 0 - ] # hidden states of question encoder => pooler_output - - retriever_outputs = self.retriever( - input_ids, - question_encoder_last_hidden_state.numpy(), - prefix=self.generator.config.prefix, - n_docs=n_docs, - return_tensors="tf", - ) - context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = ( - retriever_outputs["context_input_ids"], - retriever_outputs["context_attention_mask"], - retriever_outputs["retrieved_doc_embeds"], - retriever_outputs["doc_ids"], - ) - - context_input_ids = tf.cast(context_input_ids, tf.int32) - context_attention_mask = tf.cast(context_attention_mask, tf.int32) - retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) - retrieved_doc_ids = tf.cast(retrieved_doc_ids, tf.int32) - - # compute doc_scores - doc_scores = tf.squeeze( - tf.matmul( - tf.expand_dims(question_encoder_last_hidden_state, axis=1), - retrieved_doc_embeds, - transpose_b=True, - ), - axis=1, - ) - - else: - assert context_input_ids is not None, ( - "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can" - " set a retriever using the `set_retriever(...)` function." - ) - assert context_attention_mask is not None, ( - "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you" - " can set a retriever using the `set_retriever(...)` function." - ) - assert doc_scores is not None, ( - "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a" - " retriever using the `set_retriever(...)` function." - ) - - assert doc_scores is not None, ( - "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function." - ) - - assert (doc_scores.shape[1] % n_docs) == 0, ( - f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is" - f" {context_input_ids.shape[0]}." - ) - - # Decoder input without context documents - if decoder_input_ids is not None: - decoder_input_ids = tf.repeat(decoder_input_ids, n_docs, axis=0) - - if decoder_attention_mask is not None: - decoder_attention_mask = tf.repeat(decoder_attention_mask, n_docs, axis=0) - - gen_outputs = self.generator( - context_input_ids, - attention_mask=context_attention_mask, - encoder_outputs=encoder_outputs, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - return_dict=True, - training=training, - ) - - if not has_to_retrieve: - question_encoder_last_hidden_state = None - question_enc_hidden_states = None - question_enc_attentions = None - retrieved_doc_embeds = None - retrieved_doc_ids = None - else: - question_enc_hidden_states = question_enc_outputs.hidden_states - question_enc_attentions = question_enc_outputs.attentions - - if not has_to_retrieve or not output_retrieved: - # don't output retrieved docs - context_input_ids = (None,) - context_attention_mask = None - retrieved_doc_embeds = None - retrieved_doc_ids = None - - return TFRetrievAugLMOutput( - logits=gen_outputs.logits, - doc_scores=doc_scores, - past_key_values=gen_outputs.past_key_values, - context_input_ids=context_input_ids, - context_attention_mask=context_attention_mask, - retrieved_doc_embeds=retrieved_doc_embeds, - retrieved_doc_ids=retrieved_doc_ids, - question_encoder_last_hidden_state=question_encoder_last_hidden_state, - question_enc_hidden_states=question_enc_hidden_states, - question_enc_attentions=question_enc_attentions, - generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state, - generator_enc_hidden_states=gen_outputs.encoder_hidden_states, - generator_enc_attentions=gen_outputs.encoder_attentions, - generator_dec_hidden_states=gen_outputs.decoder_hidden_states, - generator_dec_attentions=gen_outputs.decoder_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - with tf.name_scope(self.generator.name): - self.generator.build(None) - with tf.name_scope(self.question_encoder.name): - self.question_encoder.build(None) - - -@add_start_docstrings_to_model_forward( - """ - A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass. - """, - RAG_START_DOCSTRING, -) -class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss): - load_weight_prefix = "tf_rag_token_for_generation_1/rag" - - def __init__( - self, - config: PretrainedConfig | None = None, - question_encoder: TFPreTrainedModel | None = None, - generator: TFPreTrainedModel | None = None, - retriever: RagRetriever | None = None, - **kwargs, - ): - assert config is not None or (question_encoder is not None and generator is not None), ( - "Either a configuration or an encoder and a generator has to be provided." - ) - - if config is None: - config = RagConfig.from_question_encoder_generator_configs( - question_encoder.config, generator.config, **kwargs - ) - - super().__init__(config) - - # instantiate model - self.rag = TFRagModel( - config=config, - question_encoder=question_encoder, - generator=generator, - retriever=retriever, - load_weight_prefix=self.load_weight_prefix, - name="rag", - ) - - def set_retriever(self, retriever: RagRetriever): - self.rag.retriever = retriever - - # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_bart.py - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - use_cache=None, - encoder_outputs=None, - doc_scores=None, - n_docs=None, - **kwargs, - ): - if past_key_values is not None: - # if past is defined use only last decoder_input_ids - decoder_input_ids = decoder_input_ids[:, -1:] - - return { - "input_ids": None, - "encoder_outputs": encoder_outputs, - "doc_scores": doc_scores, - "context_attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "do_marginalize": True, - "n_docs": n_docs, - } - - @property - def retriever(self): - return self.rag.retriever - - @property - def generator(self): - return self.rag.generator - - @property - def question_encoder(self): - return self.rag.question_encoder - - @staticmethod - def _gather_beams(nested, beam_indices, batch_axis=0): - """ - RAG-specific `_gather_beams`: gathers the beam slices indexed by beam_indices into new beam array. If the - nested tensor has a shape mismatch with the beam indices, then it means it is the cache. In that case, isolates - and takes care of the extra dimension for ndocs. - """ - - def gather_fn(tensor): - is_rag_cache = tensor.shape[0] != beam_indices.shape[0] - if is_rag_cache: - n_docs = tensor.shape[0] // beam_indices.shape[0] - batch_size = beam_indices.shape[0] - # reshapes into (batch size, num beams, n_docs, ...), the cache format expected by RAG - tensor = tf.reshape(tensor, (batch_size, -1, n_docs, *tensor.shape[2:])) - - gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1) - - if is_rag_cache: - # reshapes back into the shape expected by beam search - gathered_tensor = tf.reshape(gathered_tensor, (batch_size * n_docs, -1, *gathered_tensor.shape[3:])) - - return gathered_tensor - - return tf.nest.map_structure(gather_fn, nested) - - def marginalize(self, seq_logits, doc_scores, n_docs=None): - n_docs = n_docs if n_docs is not None else self.config.n_docs - - # RAG-token marginalization - seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1) - seq_logprobs = tf.reshape(seq_logprobs, [seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]]) - doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1) - doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) - doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) # twice - log_prob_sum = seq_logprobs + doc_logprobs - return tf.reduce_logsumexp(log_prob_sum, axis=1) - - @unpack_inputs - @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - doc_scores: np.ndarray | tf.Tensor | None = None, - context_input_ids: np.ndarray | tf.Tensor | None = None, - context_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_retrieved: bool | None = None, - n_docs: int | None = None, - do_marginalize: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - reduce_loss: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, # needs kwargs for generation - ) -> TFRetrievAugLMMarginOutput: - r""" - do_marginalize (`bool`, *optional*): - If `True`, the logits are marginalized over all documents by making use of - `torch.nn.functional.log_softmax`. - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss according to Rag-Token model formulation See - https://huggingface.co/papers/2005.11401 Section 2.1 for details about Rag-Token formulation. Indices should be - in `[0, ..., config.vocab_size - 1]`. - reduce_loss (`bool`, *optional*): - Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.Tensor.sum` - operation. - kwargs (`dict[str, any]`, *optional*, defaults to `{}`): - Legacy dictionary, which is required so that model can use *generate()* function. - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, RagRetriever, TFRagTokenForGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq") - >>> retriever = RagRetriever.from_pretrained( - ... "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True - ... ) - >>> # initialize with RagRetriever to do everything in one forward call - >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True) - - >>> input_dict = tokenizer.prepare_seq2seq_batch( - ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf" - ... ) - >>> outputs = model(input_dict, output_retrieved=True) - - >>> # or use retriever separately - >>> # 1. Encode - >>> input_ids = input_dict["input_ids"] - >>> question_hidden_states = model.question_encoder(input_ids)[0] - >>> # 2. Retrieve - >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf") - >>> doc_scores = tf.squeeze( - ... tf.matmul( - ... tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True - ... ), - ... axis=1, - ... ) - >>> # 3. Forward to generator - >>> outputs = model( - ... inputs=None, - ... context_input_ids=docs_dict["context_input_ids"], - ... context_attention_mask=docs_dict["context_attention_mask"], - ... doc_scores=doc_scores, - ... decoder_input_ids=input_dict["labels"], - ... ) - - >>> # or directly generate - >>> generated = model.generate( - ... context_input_ids=docs_dict["context_input_ids"], - ... context_attention_mask=docs_dict["context_attention_mask"], - ... doc_scores=doc_scores, - ... ) - >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True) - ```""" - - assert "decoder_cached_states" not in kwargs, ( - "Please use past_key_values to cache intermediate outputs" - ) # from modeling_tf_bart.py - - do_marginalize = do_marginalize if do_marginalize else self.config.do_marginalize - reduce_loss = reduce_loss if reduce_loss else self.config.reduce_loss - - if labels is not None: - if decoder_input_ids is None: - decoder_input_ids = labels - use_cache = False - - outputs = self.rag( - input_ids, - attention_mask=attention_mask, - encoder_outputs=encoder_outputs, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - context_input_ids=context_input_ids, - context_attention_mask=context_attention_mask, - doc_scores=doc_scores, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_retrieved=output_retrieved, - n_docs=n_docs, - training=training, - ) - - loss = None - logits = outputs.logits - if labels is not None: - assert decoder_input_ids is not None - loss = self.get_nll( - outputs.logits, - outputs.doc_scores, - labels, - reduce_loss=reduce_loss, - epsilon=self.config.label_smoothing, - n_docs=n_docs, - ) - - if do_marginalize: - logits = self.marginalize(logits, outputs.doc_scores, n_docs) - - return TFRetrievAugLMMarginOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - doc_scores=outputs.doc_scores, - context_input_ids=outputs.context_input_ids, - context_attention_mask=outputs.context_attention_mask, - retrieved_doc_embeds=outputs.retrieved_doc_embeds, - retrieved_doc_ids=outputs.retrieved_doc_ids, - question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state, - question_enc_hidden_states=outputs.question_enc_hidden_states, - question_enc_attentions=outputs.question_enc_attentions, - generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, - generator_enc_hidden_states=outputs.generator_enc_hidden_states, - generator_enc_attentions=outputs.generator_enc_attentions, - generator_dec_hidden_states=outputs.generator_dec_hidden_states, - generator_dec_attentions=outputs.generator_dec_attentions, - ) - - def generate( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - context_input_ids=None, - context_attention_mask=None, - doc_scores=None, - n_docs=None, - generation_config=None, - logits_processor=TFLogitsProcessorList(), - **kwargs, - ): - """ - Implements TFRAG token decoding. - - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - The sequence used as a prompt for the generation. If `input_ids` is not passed, then - `context_input_ids` has to be provided. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. - - If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the - forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. - context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. - - If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the - forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. - doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`): - Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and - `question_encoder_last_hidden_state`. - - If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the - forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. - n_docs (`int`, *optional*, defaults to `config.n_docs`) - Number of documents to retrieve and/or number of documents for which to generate an answer. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`TFLogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. - kwargs (`dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. - - Return: - `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The - second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early - due to the `eos_token_id`. - """ - # Handle `generation_config` and kwargs that might update it - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - - # set default parameters - n_docs = n_docs if n_docs is not None else self.config.n_docs - - # retrieve docs - if self.retriever is not None and context_input_ids is None: - question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] - out = self.retriever( - input_ids, - question_hidden_states.numpy().astype(np.float32), - prefix=self.generator.config.prefix, - n_docs=n_docs, - return_tensors="tf", - ) - context_input_ids, context_attention_mask, retrieved_doc_embeds = ( - out["context_input_ids"], - out["context_attention_mask"], - out["retrieved_doc_embeds"], - ) - - context_input_ids = tf.cast(context_input_ids, tf.int32) - context_attention_mask = tf.cast(context_attention_mask, tf.int32) - retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32) - - # compute doc_scores - doc_scores = tf.matmul( - tf.expand_dims(question_hidden_states, axis=1), retrieved_doc_embeds, transpose_b=True - ) - doc_scores = tf.squeeze(doc_scores, axis=1) - - assert (context_input_ids.shape[0] % n_docs) == 0, ( - f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is" - f" {context_input_ids.shape[0]}." - ) - - batch_size = context_input_ids.shape[0] // n_docs - - encoder = self.rag.generator.get_encoder() - encoder_outputs = encoder( - input_ids=context_input_ids, - attention_mask=context_attention_mask, - output_attentions=generation_config.output_attentions, - output_hidden_states=generation_config.output_hidden_states, - return_dict=True, - ) - - decoder_input_ids = tf.fill( - (batch_size * generation_config.num_beams, 1), - tf.cast(generation_config.decoder_start_token_id, tf.int32), - ) - last_hidden_state = encoder_outputs["last_hidden_state"] - - def extend_enc_output(tensor, num_beams=None): - """ - Broadcast tensor with `num_beams` replica, with correct order Input: tensor of shape (batch_size*n_docs , - d) Output: tensor of shape (batch_size*num_beams*n_docs , d) - """ - - # expand batch_size & num_beam dimensions - d_shape_list = tensor.shape[1:] - - # split n_docs dimensions - new_shape = (batch_size, 1, n_docs) + d_shape_list - tensor = tf.reshape(tensor, new_shape) - - # repeat same last hidden states over `num_beams` dimension - new_shape = (batch_size, num_beams, n_docs) + d_shape_list - tensor = tf.broadcast_to(tensor, new_shape) - - # merge `batch_size`, `num_beams`, `num_docs` dims again - new_shape = (batch_size * num_beams * n_docs,) + d_shape_list - return tf.reshape(tensor, new_shape) - - # correctly extend last_hidden_state and attention mask - context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams) - encoder_outputs["last_hidden_state"] = extend_enc_output( - last_hidden_state, num_beams=generation_config.num_beams - ) - - doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0) - - # define start_len & additional parameters - model_kwargs["doc_scores"] = doc_scores - model_kwargs["encoder_outputs"] = encoder_outputs - model_kwargs["attention_mask"] = context_attention_mask - model_kwargs["n_docs"] = n_docs - - pre_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=tf.shape(decoder_input_ids)[-1], - logits_processor=logits_processor, - ) - - if generation_config.num_beams == 1: - return self.greedy_search( - input_ids=decoder_input_ids, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - logits_processor=pre_processor, - output_attentions=generation_config.output_attentions, - output_hidden_states=generation_config.output_hidden_states, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - **model_kwargs, - ) - elif generation_config.num_beams > 1: - if generation_config.num_beams < generation_config.num_return_sequences: - raise ValueError( - "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >=" - f" num_return_sequences, got {generation_config.num_beams} and" - f" {generation_config.num_return_sequences} (respectively)" - ) - - def unflatten_beam_dim(tensor): - """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" - shape = shape_list(tensor) - return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:]) - - decoder_input_ids = unflatten_beam_dim(decoder_input_ids) - model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"]) - model_kwargs["encoder_outputs"]["last_hidden_state"] = unflatten_beam_dim( - model_kwargs["encoder_outputs"]["last_hidden_state"] - ) - - return self.beam_search( - input_ids=decoder_input_ids, - max_length=generation_config.max_length, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - logits_processor=pre_processor, - output_attentions=generation_config.output_attentions, - output_hidden_states=generation_config.output_hidden_states, - output_scores=generation_config.output_scores, - return_dict_in_generate=generation_config.return_dict_in_generate, - **model_kwargs, - ) - else: - raise ValueError( - f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}" - ) - - def get_input_embeddings(self): - return self.rag.generator.get_input_embeddings() - - def get_output_embeddings(self): - return self.rag.generator.get_output_embeddings() - - # Adapted from tf_t5's & tf_bart's _shift_right - def shift_tokens_right(self, input_ids, start_token_id=None): - """Shift input ids one token to the right, and pad with start_token_id""" - - if start_token_id is None: - start_token_id = self.generator.config.decoder_start_token_id - assert start_token_id is not None, ( - "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as" - " generator, see Bart docs for more information" - ) - - pad_token_id = self.generator.config.pad_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.cast(start_token_id, input_ids.dtype)) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.cast(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, shifted_input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - # nll stands for 'negative log likelihood' - def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None): - n_docs = n_docs if n_docs is not None else self.config.n_docs - # shift tokens left (from original Pytorch's version) - - target = tf.concat( - [target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))], - axis=1, - ) - rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs) - loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss) - - return loss - - # Adopted modeling_tf_bart + add smooth_loss to match with pytorch version - def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False): - """CrossEntropyLoss that ignores pad tokens""" - # Matt: As written, this loss is not XLA-compatible, but it's doing some very weird things - # and I don't feel comfortable converting it. - loss_fn = keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=keras.losses.Reduction.SUM, - ) - - if from_logits is False: # convert to logits - eps = 1e-9 - y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps) - y_pred = tf.math.log(y_pred) - - logits = y_pred - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id) - - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - nll_loss = loss_fn(labels, reduced_logits) - - smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1) - smooth_loss = tf.reduce_sum(smooth_loss) # sum and squeeze like torch - eps_i = smooth_epsilon / reduced_logits.shape[-1] - - loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss - - return loss - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rag", None) is not None: - with tf.name_scope(self.rag.name): - self.rag.build(None) - - -@add_start_docstrings_to_model_forward( - """ - A TF RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass. - """, - RAG_START_DOCSTRING, -) -class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss): - load_weight_prefix = "tf_rag_sequence_for_generation_1/rag" - - def __init__( - self, - config: PretrainedConfig | None = None, - question_encoder: TFPreTrainedModel | None = None, - generator: TFPreTrainedModel | None = None, - retriever: RagRetriever | None = None, - **kwargs, - ): - assert config is not None or (question_encoder is not None and generator is not None), ( - "Either a configuration or an encoder and a generator has to be provided." - ) - - if config is None: - config = RagConfig.from_question_encoder_generator_configs( - question_encoder.config, generator.config, **kwargs - ) - - super().__init__(config) - - # instantiate model - self.rag = TFRagModel( - config=config, - question_encoder=question_encoder, - generator=generator, - retriever=retriever, - load_weight_prefix=self.load_weight_prefix, - name="rag", - ) - - def set_retriever(self, retriever: RagRetriever): - self.rag.retriever = retriever - - @property - def retriever(self): - return self.rag.retriever - - @property - def generator(self): - return self.rag.generator - - @property - def question_encoder(self): - return self.rag.question_encoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - doc_scores: np.ndarray | tf.Tensor | None = None, - context_input_ids: np.ndarray | tf.Tensor | None = None, - context_attention_mask: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_retrieved: bool | None = None, - n_docs: int | None = None, - exclude_bos_score: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - reduce_loss: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, # needs kwargs for generation - ) -> tuple[tf.Tensor] | TFRetrievAugLMMarginOutput: - r""" - exclude_bos_score (`bool`, *optional*): - Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing - the loss. - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss according to Rag-Sequence model formulation See - https://huggingface.co/papers/2005.11401 Section 2.1 for details about Rag-Sequence formulation. Indices should - be in `[0, ..., config.vocab_size - 1]`. - reduce_loss (`bool`, *optional*): - Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.Tensor.sum` - operation. - kwargs (`dict[str, any]`, *optional*, defaults to `{}`): - Legacy dictionary, which is required so that model can use *generate()* function. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, RagRetriever, TFRagSequenceForGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq") - >>> retriever = RagRetriever.from_pretrained( - ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True - ... ) - >>> # initialize with RagRetriever to do everything in one forward call - >>> model = TFRagSequenceForGeneration.from_pretrained( - ... "facebook/rag-sequence-nq", retriever=retriever, from_pt=True - ... ) - - >>> input_dict = tokenizer.prepare_seq2seq_batch( - ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf" - ... ) - >>> outputs = model(input_dict, output_retrieved=True) - - >>> # or use retriever separately - >>> # 1. Encode - >>> input_ids = input_dict["input_ids"] - >>> question_hidden_states = model.question_encoder(input_ids)[0] - >>> # 2. Retrieve - >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf") - >>> doc_scores = tf.squeeze( - ... tf.matmul( - ... tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True - ... ), - ... axis=1, - ... ) - >>> # 3. Forward to generator - >>> outputs = model( - ... inputs=None, - ... context_input_ids=docs_dict["context_input_ids"], - ... context_attention_mask=docs_dict["context_attention_mask"], - ... doc_scores=doc_scores, - ... decoder_input_ids=input_dict["labels"], - ... ) - - >>> # or directly generate - >>> generated = model.generate( - ... context_input_ids=docs_dict["context_input_ids"], - ... context_attention_mask=docs_dict["context_attention_mask"], - ... doc_scores=doc_scores, - ... ) - >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True) - ```""" - - assert "decoder_cached_states" not in kwargs, ( - "Please use past_key_values to cache intermediate outputs" - ) # from modeling_tf_bart.py - - exclude_bos_score = exclude_bos_score if exclude_bos_score else self.config.exclude_bos_score - reduce_loss = reduce_loss if reduce_loss else self.config.reduce_loss - - if labels is not None: - if decoder_input_ids is None: - decoder_input_ids = labels - use_cache = False - - outputs = self.rag( - input_ids, - attention_mask=attention_mask, - encoder_outputs=encoder_outputs, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - context_input_ids=context_input_ids, - context_attention_mask=context_attention_mask, - doc_scores=doc_scores, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_retrieved=output_retrieved, - n_docs=n_docs, - training=training, - ) - - loss = None - if labels is not None: - loss = self.get_nll( - outputs.logits, - outputs.doc_scores, - labels, - reduce_loss=reduce_loss, - epsilon=self.config.label_smoothing, - n_docs=n_docs, - ) - - return TFRetrievAugLMMarginOutput( - loss=loss, - logits=outputs.logits, - doc_scores=outputs.doc_scores, - past_key_values=outputs.past_key_values, - context_input_ids=outputs.context_input_ids, - context_attention_mask=outputs.context_attention_mask, - retrieved_doc_embeds=outputs.retrieved_doc_embeds, - retrieved_doc_ids=outputs.retrieved_doc_ids, - question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state, - question_enc_hidden_states=outputs.question_enc_hidden_states, - question_enc_attentions=outputs.question_enc_attentions, - generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state, - generator_enc_hidden_states=outputs.generator_enc_hidden_states, - generator_enc_attentions=outputs.generator_enc_attentions, - generator_dec_hidden_states=outputs.generator_dec_hidden_states, - generator_dec_attentions=outputs.generator_dec_attentions, - ) - - def get_nll( - self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None - ): - # shift tokens left - target = tf.concat( - [target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))], - axis=1, - ) - - # bos_token_id is None for T5 - bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id - n_docs = n_docs if n_docs is not None else self.config.n_docs - equal_bos_token_id_all = tf.reduce_all(tf.equal(target[:, 0], bos_token_id)) - use_bos = bos_token_id is not None and equal_bos_token_id_all - - def _mask_pads(ll, smooth_obj): - pad_mask = tf.equal(target, tf.cast(self.config.generator.pad_token_id, target.dtype)) - if tf.reduce_any(pad_mask): - ll = tf.where(pad_mask, 0.0, ll) - smooth_obj = tf.where(pad_mask, 0.0, smooth_obj) - return tf.squeeze(ll, axis=-1), tf.squeeze(smooth_obj, axis=-1) - - # seq_logits.shape = (batch*n_docs, tgt_len , vocabs) - seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1) - seq_logprobs = tf.reshape( - seq_logprobs, (seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]) - ) # (batch_size, n_docs, tgt_len, vocabs) - doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1) - doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) - doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1) # done twice to get 4-D - - # RAG-sequence marginalization - first_token_scores = seq_logprobs[:, :, :1, :] - second_token_scores = seq_logprobs[:, :, 1:2, :] - remainder = seq_logprobs[:, :, 2:, :] - rag_logprobs = tf.concat([first_token_scores, second_token_scores + doc_logprobs, remainder], axis=2) - - # calculate loss - target = tf.expand_dims(target, axis=1) # n_docs dimension - target = tf.expand_dims(target, axis=-1) # logits dimension - target = tf.repeat(target, n_docs, axis=1) - assert len(target.shape) == len(rag_logprobs.shape) - - # last-axis gathering only - use 2D-reshape-trick for Torch's style nD gathering - def torch_gather(param, id_tensor): - # 2d-gather torch equivalent: https://stackoverflow.com/questions/52129909/tensorflow-equivalent-of-torch-gather - def gather2d(target, id_tensor): - idx = tf.stack([tf.range(tf.shape(id_tensor)[0], dtype=id_tensor.dtype), id_tensor[:, 0]], axis=-1) - result = tf.gather_nd(target, idx) - return tf.expand_dims(result, axis=-1) - - target = tf.reshape(param, (-1, param.shape[-1])) # reshape 2D - target_shape = id_tensor.shape - - id_tensor = tf.reshape(id_tensor, (-1, 1)) # also 2D-index - result = gather2d(target, id_tensor) - return tf.reshape(result, target_shape) - - ll = torch_gather(rag_logprobs, id_tensor=target) - smooth_obj = tf.reduce_sum(rag_logprobs, axis=-1, keepdims=True) # total sum of all (normalised) logits - - ll, smooth_obj = _mask_pads(ll, smooth_obj) - - # sum over tokens, exclude bos while scoring - if exclude_bos_score and use_bos: - ll = tf.reduce_sum(ll[:, :, 1:], axis=2) - else: - ll = tf.reduce_sum(ll, axis=2) - - smooth_obj = tf.reduce_sum(smooth_obj, axis=2) - ll = tf.math.reduce_logsumexp(ll, axis=1) # logsumexp over docs - smooth_obj = tf.math.reduce_logsumexp(smooth_obj, axis=1) - - nll_loss = -ll - smooth_loss = -smooth_obj - - if reduce_loss: - nll_loss = tf.reduce_sum(nll_loss) - smooth_loss = tf.reduce_sum(smooth_loss) - - eps_i = epsilon / rag_logprobs.shape[-1] - loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss - return loss - - def generate( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - context_input_ids=None, - context_attention_mask=None, - doc_scores=None, - do_deduplication=None, # defaults to True - num_return_sequences=None, # defaults to 1 - num_beams=None, # defaults to 1 - n_docs=None, - **model_kwargs, - ): - """ - Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation - for more information on how to set other generate input parameters - - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - The sequence used as a prompt for the generation. If `input_ids` is not passed, then - `context_input_ids` has to be provided. - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for - tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention - masks?](../glossary#attention-mask) - context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Input IDs post-processed from the retrieved documents and the question encoder input_ids by the - retriever. - context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): - Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the - retriever. If the model has is not initialized with a `retriever` or `input_ids` is not given, - `context_input_ids` and `context_attention_mask` have to be provided to the forward pass. They are - returned by [`~RagRetriever.__call__`]. - doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`): - Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and - `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` or - `input_ids` is not given, `doc_scores` has to be provided to the forward pass. `doc_scores` are - returned by [`~RagRetriever.__call__`]. - do_deduplication (`bool`, *optional*): - Whether or not to deduplicate the generations from different context documents for a given input. Has - to be set to `False` if used while training with distributed backend. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. Note that this - is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function, - where we set `num_return_sequences` to `num_beams`. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - n_docs (`int`, *optional*, defaults to `config.n_docs`) - Number of documents to retrieve and/or number of documents for which to generate an answer. - kwargs (`dict[str, Any]`, *optional*): - Additional kwargs will be passed to [`~generation.GenerationMixin.generate`] - - Return: - `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The - second dimension (sequence length) is either equal to `max_length` or shorter if all batches finished early - due to the `eos_token_id`. - """ - - n_docs = n_docs if n_docs is not None else self.config.n_docs - do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication - num_doc_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) - num_beams = num_beams if num_beams is not None else self.config.num_beams - - assert input_ids is not None or context_input_ids is not None, ( - " At least one of input_ids or context_input_ids must be given" - ) - - if self.retriever is not None and context_input_ids is None: - question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] - context_input_ids = self.retriever( - input_ids, - question_hidden_states.numpy(), - prefix=self.generator.config.prefix, - n_docs=n_docs, - return_tensors="tf", - )["context_input_ids"] - - hypos = [] - model_kwargs["num_beams"] = num_beams - model_kwargs["num_return_sequences"] = num_beams # put here so that not confused with num_doc_return_sequences - model_kwargs["attention_mask"] = None - - batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs - - for index in range(batch_size): - # first, generate beams from documents: - generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs] # (n_docs, max_len) - - output_sequences = self.generator.generate( - generator_input_ids, - **model_kwargs, - ) # n_docs * n_beam, tgt_len - if do_deduplication: - # do_deduplication -- for TF, work on Eager mode only! - output_sequences = tf.stack(list({str(k.numpy().tolist()): k for k in output_sequences}.values())) - - num_candidates = output_sequences.shape[ - 0 - ] # after deduplication, this number can be less than n_docs*n_beam - - # then, run model forwards to get nll scores: - if input_ids is not None: - new_input_ids = tf.tile(input_ids[index : index + 1], (num_candidates, 1)) - outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True) - else: # input_ids is None, need context_input_ids/mask and doc_scores - assert context_attention_mask is not None, ( - "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you" - " can set a retriever using the `set_retriever(...)` function." - ) - assert doc_scores is not None, ( - "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a" - " retriever using the `set_retriever(...)` function." - ) - - individual_input_ids = tf.tile( - generator_input_ids, (num_candidates, 1) - ) # (num_candidates*n_docs, max_len) - - individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs] - individual_attention_mask = tf.tile(individual_attention_mask, (num_candidates, 1)) - - individual_doc_scores = doc_scores[index : (index + 1), :] # doc_scores.shape = [batch, n_docs] - individual_doc_scores = tf.tile(individual_doc_scores, (num_candidates, 1)) # [num_candidates, n_docs] - - outputs = self( - input_ids=None, - context_input_ids=individual_input_ids, - context_attention_mask=individual_attention_mask, - doc_scores=individual_doc_scores, - labels=output_sequences, - exclude_bos_score=True, - ) - - top_cand_inds = tf.math.top_k((-outputs["loss"]), k=num_doc_return_sequences)[1] - - # add hypothesis - hypos.append(tf.gather(output_sequences, top_cand_inds)) - - return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id) - - @staticmethod - def _cat_and_pad(tensors, pad_token_id): - # used by generate(): tensors is a (batched) list of (candidates, len); len is varied across batch - - # Initialize padded tensor with shape ( all_candidates , max_candidate_length ), - # where all_candidates counted from all inputs - new_shape = sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors]) - output = tf.fill(new_shape, pad_token_id) - - # Normal tensor doesn't support slice assignment, so we need tf.Variable - output = tf.Variable(output) - - # Assign, and then convert back to tensor - ind = 0 - for t in tensors: - output[ind : ind + t.shape[0], : t.shape[1]].assign(t) - ind += t.shape[0] - - output = tf.convert_to_tensor(output) - return tf.cast(output, tensors[0][0][0].dtype) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rag", None) is not None: - with tf.name_scope(self.rag.name): - self.rag.build(None) - - -__all__ = ["TFRagModel", "TFRagPreTrainedModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] diff --git a/src/transformers/models/regnet/modeling_flax_regnet.py b/src/transformers/models/regnet/modeling_flax_regnet.py deleted file mode 100644 index 2cc3707fa51a..000000000000 --- a/src/transformers/models/regnet/modeling_flax_regnet.py +++ /dev/null @@ -1,822 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict - -from transformers import RegNetConfig -from transformers.modeling_flax_outputs import ( - FlaxBaseModelOutputWithNoAttention, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndNoAttention, - FlaxImageClassifierOutputWithNoAttention, -) -from transformers.modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, -) - - -REGNET_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -REGNET_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`RegNetImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.resnet.modeling_flax_resnet.Identity -class Identity(nn.Module): - """Identity function.""" - - @nn.compact - def __call__(self, x, **kwargs): - return x - - -class FlaxRegNetConvLayer(nn.Module): - out_channels: int - kernel_size: int = 3 - stride: int = 1 - groups: int = 1 - activation: Optional[str] = "relu" - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.convolution = nn.Conv( - self.out_channels, - kernel_size=(self.kernel_size, self.kernel_size), - strides=self.stride, - padding=self.kernel_size // 2, - feature_group_count=self.groups, - use_bias=False, - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), - dtype=self.dtype, - ) - self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) - self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.convolution(hidden_state) - hidden_state = self.normalization(hidden_state, use_running_average=deterministic) - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxRegNetEmbeddings(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embedder = FlaxRegNetConvLayer( - self.config.embedding_size, - kernel_size=3, - stride=2, - activation=self.config.hidden_act, - dtype=self.dtype, - ) - - def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - num_channels = pixel_values.shape[-1] - if num_channels != self.config.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - hidden_state = self.embedder(pixel_values, deterministic=deterministic) - return hidden_state - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet -class FlaxRegNetShortCut(nn.Module): - """ - RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to - downsample the input using `stride=2`. - """ - - out_channels: int - stride: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.convolution = nn.Conv( - self.out_channels, - kernel_size=(1, 1), - strides=self.stride, - use_bias=False, - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), - dtype=self.dtype, - ) - self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.convolution(x) - hidden_state = self.normalization(hidden_state, use_running_average=deterministic) - return hidden_state - - -class FlaxRegNetSELayerCollection(nn.Module): - in_channels: int - reduced_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv_1 = nn.Conv( - self.reduced_channels, - kernel_size=(1, 1), - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), - dtype=self.dtype, - name="0", - ) # 0 is the name used in corresponding pytorch implementation - self.conv_2 = nn.Conv( - self.in_channels, - kernel_size=(1, 1), - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), - dtype=self.dtype, - name="2", - ) # 2 is the name used in corresponding pytorch implementation - - def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: - hidden_state = self.conv_1(hidden_state) - hidden_state = nn.relu(hidden_state) - hidden_state = self.conv_2(hidden_state) - attention = nn.sigmoid(hidden_state) - - return attention - - -class FlaxRegNetSELayer(nn.Module): - """ - Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://huggingface.co/papers/1709.01507). - """ - - in_channels: int - reduced_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) - self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype) - - def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: - pooled = self.pooler( - hidden_state, - window_shape=(hidden_state.shape[1], hidden_state.shape[2]), - strides=(hidden_state.shape[1], hidden_state.shape[2]), - ) - attention = self.attention(pooled) - hidden_state = hidden_state * attention - return hidden_state - - -class FlaxRegNetXLayerCollection(nn.Module): - config: RegNetConfig - out_channels: int - stride: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - groups = max(1, self.out_channels // self.config.groups_width) - - self.layer = [ - FlaxRegNetConvLayer( - self.out_channels, - kernel_size=1, - activation=self.config.hidden_act, - dtype=self.dtype, - name="0", - ), - FlaxRegNetConvLayer( - self.out_channels, - stride=self.stride, - groups=groups, - activation=self.config.hidden_act, - dtype=self.dtype, - name="1", - ), - FlaxRegNetConvLayer( - self.out_channels, - kernel_size=1, - activation=None, - dtype=self.dtype, - name="2", - ), - ] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - for layer in self.layer: - hidden_state = layer(hidden_state, deterministic=deterministic) - return hidden_state - - -class FlaxRegNetXLayer(nn.Module): - """ - RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. - """ - - config: RegNetConfig - in_channels: int - out_channels: int - stride: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 - self.shortcut = ( - FlaxRegNetShortCut( - self.out_channels, - stride=self.stride, - dtype=self.dtype, - ) - if should_apply_shortcut - else Identity() - ) - self.layer = FlaxRegNetXLayerCollection( - self.config, - in_channels=self.in_channels, - out_channels=self.out_channels, - stride=self.stride, - dtype=self.dtype, - ) - self.activation_func = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - residual = hidden_state - hidden_state = self.layer(hidden_state) - residual = self.shortcut(residual, deterministic=deterministic) - hidden_state += residual - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxRegNetYLayerCollection(nn.Module): - config: RegNetConfig - in_channels: int - out_channels: int - stride: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - groups = max(1, self.out_channels // self.config.groups_width) - - self.layer = [ - FlaxRegNetConvLayer( - self.out_channels, - kernel_size=1, - activation=self.config.hidden_act, - dtype=self.dtype, - name="0", - ), - FlaxRegNetConvLayer( - self.out_channels, - stride=self.stride, - groups=groups, - activation=self.config.hidden_act, - dtype=self.dtype, - name="1", - ), - FlaxRegNetSELayer( - self.out_channels, - reduced_channels=int(round(self.in_channels / 4)), - dtype=self.dtype, - name="2", - ), - FlaxRegNetConvLayer( - self.out_channels, - kernel_size=1, - activation=None, - dtype=self.dtype, - name="3", - ), - ] - - def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: - for layer in self.layer: - hidden_state = layer(hidden_state) - return hidden_state - - -class FlaxRegNetYLayer(nn.Module): - """ - RegNet's Y layer: an X layer with Squeeze and Excitation. - """ - - config: RegNetConfig - in_channels: int - out_channels: int - stride: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 - - self.shortcut = ( - FlaxRegNetShortCut( - self.out_channels, - stride=self.stride, - dtype=self.dtype, - ) - if should_apply_shortcut - else Identity() - ) - self.layer = FlaxRegNetYLayerCollection( - self.config, - in_channels=self.in_channels, - out_channels=self.out_channels, - stride=self.stride, - dtype=self.dtype, - ) - self.activation_func = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - residual = hidden_state - hidden_state = self.layer(hidden_state) - residual = self.shortcut(residual, deterministic=deterministic) - hidden_state += residual - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxRegNetStageLayersCollection(nn.Module): - """ - A RegNet stage composed by stacked layers. - """ - - config: RegNetConfig - in_channels: int - out_channels: int - stride: int = 2 - depth: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer - - layers = [ - # downsampling is done in the first layer with stride of 2 - layer( - self.config, - self.in_channels, - self.out_channels, - stride=self.stride, - dtype=self.dtype, - name="0", - ) - ] - - for i in range(self.depth - 1): - layers.append( - layer( - self.config, - self.out_channels, - self.out_channels, - dtype=self.dtype, - name=str(i + 1), - ) - ) - - self.layers = layers - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = x - for layer in self.layers: - hidden_state = layer(hidden_state, deterministic=deterministic) - return hidden_state - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet -class FlaxRegNetStage(nn.Module): - """ - A RegNet stage composed by stacked layers. - """ - - config: RegNetConfig - in_channels: int - out_channels: int - stride: int = 2 - depth: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = FlaxRegNetStageLayersCollection( - self.config, - in_channels=self.in_channels, - out_channels=self.out_channels, - stride=self.stride, - depth=self.depth, - dtype=self.dtype, - ) - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - return self.layers(x, deterministic=deterministic) - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet -class FlaxRegNetStageCollection(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) - stages = [ - FlaxRegNetStage( - self.config, - self.config.embedding_size, - self.config.hidden_sizes[0], - stride=2 if self.config.downsample_in_first_stage else 1, - depth=self.config.depths[0], - dtype=self.dtype, - name="0", - ) - ] - - for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): - stages.append( - FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) - ) - - self.stages = stages - - def __call__( - self, - hidden_state: jnp.ndarray, - output_hidden_states: bool = False, - deterministic: bool = True, - ) -> FlaxBaseModelOutputWithNoAttention: - hidden_states = () if output_hidden_states else None - - for stage_module in self.stages: - if output_hidden_states: - hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) - - hidden_state = stage_module(hidden_state, deterministic=deterministic) - - return hidden_state, hidden_states - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet -class FlaxRegNetEncoder(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_state: jnp.ndarray, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> FlaxBaseModelOutputWithNoAttention: - hidden_state, hidden_states = self.stages( - hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic - ) - - if output_hidden_states: - hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states] if v is not None) - - return FlaxBaseModelOutputWithNoAttention( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - ) - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET -class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RegNetConfig - base_model_prefix = "regnet" - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: RegNetConfig, - input_shape=(1, 224, 224, 3), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, config.num_channels) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - pixel_values = jnp.zeros(input_shape, dtype=self.dtype) - - rngs = {"params": rng} - - random_params = self.module.init(rngs, pixel_values, return_dict=False) - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) - def __call__( - self, - pixel_values, - params: Optional[dict] = None, - train: bool = False, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - - return self.module.apply( - { - "params": params["params"] if params is not None else self.params["params"], - "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], - }, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=["batch_stats"] if train else False, # Returning tuple with batch_stats only when train is True - ) - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet -class FlaxRegNetModule(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype) - - # Adaptive average pooling used in resnet - self.pooler = partial( - nn.avg_pool, - padding=((0, 0), (0, 0)), - ) - - def __call__( - self, - pixel_values, - deterministic: bool = True, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - embedding_output = self.embedder(pixel_values, deterministic=deterministic) - - encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - last_hidden_state = encoder_outputs[0] - - pooled_output = self.pooler( - last_hidden_state, - window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), - strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), - ).transpose(0, 3, 1, 2) - - last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - ) - - -@add_start_docstrings( - "The bare RegNet model outputting raw features without any specific head on top.", - REGNET_START_DOCSTRING, -) -class FlaxRegNetModel(FlaxRegNetPreTrainedModel): - module_class = FlaxRegNetModule - - -FLAX_VISION_MODEL_DOCSTRING = """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, FlaxRegNetModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") - >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING) -append_replace_return_docstrings( - FlaxRegNetModel, - output_type=FlaxBaseModelOutputWithPooling, - config_class=RegNetConfig, -) - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet -class FlaxRegNetClassifierCollection(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - return self.classifier(x) - - -# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET -class FlaxRegNetForImageClassificationModule(nn.Module): - config: RegNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype) - - if self.config.num_labels > 0: - self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype) - else: - self.classifier = Identity() - - def __call__( - self, - pixel_values=None, - deterministic: bool = True, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.regnet( - pixel_values, - deterministic=deterministic, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(pooled_output[:, :, 0, 0]) - - if not return_dict: - output = (logits,) + outputs[2:] - return output - - return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) - - -@add_start_docstrings( - """ - RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - REGNET_START_DOCSTRING, -) -class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel): - module_class = FlaxRegNetForImageClassificationModule - - -FLAX_VISION_CLASSIF_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification - >>> from PIL import Image - >>> import jax - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") - >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) - >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) - ``` -""" - -overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) -append_replace_return_docstrings( - FlaxRegNetForImageClassification, - output_type=FlaxImageClassifierOutputWithNoAttention, - config_class=RegNetConfig, -) - - -__all__ = ["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"] diff --git a/src/transformers/models/regnet/modeling_tf_regnet.py b/src/transformers/models/regnet/modeling_tf_regnet.py deleted file mode 100644 index 13714b4e69aa..000000000000 --- a/src/transformers/models/regnet/modeling_tf_regnet.py +++ /dev/null @@ -1,611 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow RegNet model.""" - -from typing import Optional, Union - -import tensorflow as tf - -from ...activations_tf import ACT2FN -from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithNoAttention, - TFBaseModelOutputWithPoolingAndNoAttention, - TFSequenceClassifierOutput, -) -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import logging -from .configuration_regnet import RegNetConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "RegNetConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040" -_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -class TFRegNetConvLayer(keras.layers.Layer): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - groups: int = 1, - activation: Optional[str] = "relu", - **kwargs, - ): - super().__init__(**kwargs) - # The padding and conv has been verified in - # https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb - self.padding = keras.layers.ZeroPadding2D(padding=kernel_size // 2) - self.convolution = keras.layers.Conv2D( - filters=out_channels, - kernel_size=kernel_size, - strides=stride, - padding="VALID", - groups=groups, - use_bias=False, - name="convolution", - ) - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") - self.activation = ACT2FN[activation] if activation is not None else tf.identity - self.in_channels = in_channels - self.out_channels = out_channels - - def call(self, hidden_state): - hidden_state = self.convolution(self.padding(hidden_state)) - hidden_state = self.normalization(hidden_state) - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution", None) is not None: - with tf.name_scope(self.convolution.name): - self.convolution.build([None, None, None, self.in_channels]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.out_channels]) - - -class TFRegNetEmbeddings(keras.layers.Layer): - """ - RegNet Embeddings (stem) composed of a single aggressive convolution. - """ - - def __init__(self, config: RegNetConfig, **kwargs): - super().__init__(**kwargs) - self.num_channels = config.num_channels - self.embedder = TFRegNetConvLayer( - in_channels=config.num_channels, - out_channels=config.embedding_size, - kernel_size=3, - stride=2, - activation=config.hidden_act, - name="embedder", - ) - - def call(self, pixel_values): - num_channels = shape_list(pixel_values)[1] - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - hidden_state = self.embedder(pixel_values) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedder", None) is not None: - with tf.name_scope(self.embedder.name): - self.embedder.build(None) - - -class TFRegNetShortCut(keras.layers.Layer): - """ - RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to - downsample the input using `stride=2`. - """ - - def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs): - super().__init__(**kwargs) - self.convolution = keras.layers.Conv2D( - filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" - ) - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") - self.in_channels = in_channels - self.out_channels = out_channels - - def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: - return self.normalization(self.convolution(inputs), training=training) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution", None) is not None: - with tf.name_scope(self.convolution.name): - self.convolution.build([None, None, None, self.in_channels]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.out_channels]) - - -class TFRegNetSELayer(keras.layers.Layer): - """ - Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://huggingface.co/papers/1709.01507). - """ - - def __init__(self, in_channels: int, reduced_channels: int, **kwargs): - super().__init__(**kwargs) - self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") - self.attention = [ - keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"), - keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"), - ] - self.in_channels = in_channels - self.reduced_channels = reduced_channels - - def call(self, hidden_state): - # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels] - pooled = self.pooler(hidden_state) - for layer_module in self.attention: - pooled = layer_module(pooled) - hidden_state = hidden_state * pooled - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build((None, None, None, None)) - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention[0].name): - self.attention[0].build([None, None, None, self.in_channels]) - with tf.name_scope(self.attention[1].name): - self.attention[1].build([None, None, None, self.reduced_channels]) - - -class TFRegNetXLayer(keras.layers.Layer): - """ - RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. - """ - - def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): - super().__init__(**kwargs) - should_apply_shortcut = in_channels != out_channels or stride != 1 - groups = max(1, out_channels // config.groups_width) - self.shortcut = ( - TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") - if should_apply_shortcut - else keras.layers.Activation("linear", name="shortcut") - ) - # `self.layers` instead of `self.layer` because that is a reserved argument. - self.layers = [ - TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), - TFRegNetConvLayer( - out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" - ), - TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"), - ] - self.activation = ACT2FN[config.hidden_act] - - def call(self, hidden_state): - residual = hidden_state - for layer_module in self.layers: - hidden_state = layer_module(hidden_state) - residual = self.shortcut(residual) - hidden_state += residual - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "shortcut", None) is not None: - with tf.name_scope(self.shortcut.name): - self.shortcut.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFRegNetYLayer(keras.layers.Layer): - """ - RegNet's Y layer: an X layer with Squeeze and Excitation. - """ - - def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): - super().__init__(**kwargs) - should_apply_shortcut = in_channels != out_channels or stride != 1 - groups = max(1, out_channels // config.groups_width) - self.shortcut = ( - TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") - if should_apply_shortcut - else keras.layers.Activation("linear", name="shortcut") - ) - self.layers = [ - TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), - TFRegNetConvLayer( - out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" - ), - TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"), - TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"), - ] - self.activation = ACT2FN[config.hidden_act] - - def call(self, hidden_state): - residual = hidden_state - for layer_module in self.layers: - hidden_state = layer_module(hidden_state) - residual = self.shortcut(residual) - hidden_state += residual - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "shortcut", None) is not None: - with tf.name_scope(self.shortcut.name): - self.shortcut.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFRegNetStage(keras.layers.Layer): - """ - A RegNet stage composed by stacked layers. - """ - - def __init__( - self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs - ): - super().__init__(**kwargs) - - layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer - self.layers = [ - # downsampling is done in the first layer with stride of 2 - layer(config, in_channels, out_channels, stride=stride, name="layers.0"), - *[layer(config, out_channels, out_channels, name=f"layers.{i + 1}") for i in range(depth - 1)], - ] - - def call(self, hidden_state): - for layer_module in self.layers: - hidden_state = layer_module(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFRegNetEncoder(keras.layers.Layer): - def __init__(self, config: RegNetConfig, **kwargs): - super().__init__(**kwargs) - self.stages = [] - # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input - self.stages.append( - TFRegNetStage( - config, - config.embedding_size, - config.hidden_sizes[0], - stride=2 if config.downsample_in_first_stage else 1, - depth=config.depths[0], - name="stages.0", - ) - ) - in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) - for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])): - self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) - - def call( - self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True - ) -> TFBaseModelOutputWithNoAttention: - hidden_states = () if output_hidden_states else None - - for stage_module in self.stages: - if output_hidden_states: - hidden_states = hidden_states + (hidden_state,) - - hidden_state = stage_module(hidden_state) - - if output_hidden_states: - hidden_states = hidden_states + (hidden_state,) - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states] if v is not None) - - return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - for stage in self.stages: - with tf.name_scope(stage.name): - stage.build(None) - - -@keras_serializable -class TFRegNetMainLayer(keras.layers.Layer): - config_class = RegNetConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.embedder = TFRegNetEmbeddings(config, name="embedder") - self.encoder = TFRegNetEncoder(config, name="encoder") - self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndNoAttention: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - embedding_output = self.embedder(pixel_values, training=training) - - encoder_outputs = self.encoder( - embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = self.pooler(last_hidden_state) - - # Change to NCHW output format have uniformity in the modules - pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2)) - last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) - - # Change the other hidden state outputs to NCHW as well - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedder", None) is not None: - with tf.name_scope(self.embedder.name): - self.embedder.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build((None, None, None, None)) - - -class TFRegNetPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RegNetConfig - base_model_prefix = "regnet" - main_input_name = "pixel_values" - - @property - def input_signature(self): - return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)} - - -REGNET_START_DOCSTRING = r""" - This model is a Tensorflow - [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a - regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -REGNET_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`ConveNextImageProcessor.__call__`] for details. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare RegNet model outputting raw features without any specific head on top.", - REGNET_START_DOCSTRING, -) -class TFRegNetModel(TFRegNetPreTrainedModel): - def __init__(self, config: RegNetConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.regnet = TFRegNetMainLayer(config, name="regnet") - - @unpack_inputs - @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndNoAttention, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: tf.Tensor, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, tuple[tf.Tensor]]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.regnet( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - if not return_dict: - return (outputs[0],) + outputs[1:] - - return TFBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=outputs.last_hidden_state, - pooler_output=outputs.pooler_output, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "regnet", None) is not None: - with tf.name_scope(self.regnet.name): - self.regnet.build(None) - - -@add_start_docstrings( - """ - RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - REGNET_START_DOCSTRING, -) -class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: RegNetConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.regnet = TFRegNetMainLayer(config, name="regnet") - # classification head - self.classifier = [ - keras.layers.Flatten(), - keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity, - ] - - @unpack_inputs - @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[TFSequenceClassifierOutput, tuple[tf.Tensor]]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.regnet( - pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - flattened_output = self.classifier[0](pooled_output) - logits = self.classifier[1](flattened_output) - - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "regnet", None) is not None: - with tf.name_scope(self.regnet.name): - self.regnet.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier[1].name): - self.classifier[1].build([None, None, None, self.config.hidden_sizes[-1]]) - - -__all__ = ["TFRegNetForImageClassification", "TFRegNetModel", "TFRegNetPreTrainedModel"] diff --git a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 369388c540f9..000000000000 --- a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,62 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert RemBERT checkpoint.""" - -import argparse - -import torch - -from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): - # Initialise PyTorch model - config = RemBertConfig.from_json_file(bert_config_file) - print(f"Building PyTorch model from configuration: {str(config)}") - model = RemBertModel(config) - - # Load weights from tf checkpoint - load_tf_weights_in_rembert(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--rembert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained RemBERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_rembert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.rembert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py deleted file mode 100644 index baf7b6e8adc9..000000000000 --- a/src/transformers/models/rembert/modeling_tf_rembert.py +++ /dev/null @@ -1,1720 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 RemBERT model.""" - -from __future__ import annotations - -import math - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_rembert import RemBertConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "RemBertConfig" - - -class TFRemBertEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.input_embedding_size = config.input_embedding_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.input_embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.input_embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.input_embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.input_embedding_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - past_key_values_length=0, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->RemBert -class TFRemBertSelfAttention(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFRemBertModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RemBert -class TFRemBertSelfOutput(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->RemBert -class TFRemBertAttention(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFRemBertSelfAttention(config, name="self") - self.dense_output = TFRemBertSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RemBert -class TFRemBertIntermediate(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RemBert -class TFRemBertOutput(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RemBert -class TFRemBertLayer(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFRemBertAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFRemBertAttention(config, name="crossattention") - self.intermediate = TFRemBertIntermediate(config, name="intermediate") - self.bert_output = TFRemBertOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -class TFRemBertEncoder(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.embedding_hidden_mapping_in = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="embedding_hidden_mapping_in", - ) - self.layer = [TFRemBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_values: tuple[tuple[tf.Tensor]], - use_cache: bool, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states) - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedding_hidden_mapping_in", None) is not None: - with tf.name_scope(self.embedding_hidden_mapping_in.name): - self.embedding_hidden_mapping_in.build([None, None, self.config.input_embedding_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->RemBert -class TFRemBertPooler(keras.layers.Layer): - def __init__(self, config: RemBertConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFRemBertLMPredictionHead(keras.layers.Layer): - def __init__(self, config: RemBertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.initializer_range = config.initializer_range - self.output_embedding_size = config.output_embedding_size - self.dense = keras.layers.Dense( - config.output_embedding_size, kernel_initializer=get_initializer(self.initializer_range), name="dense" - ) - if isinstance(config.hidden_act, str): - self.activation = get_tf_activation(config.hidden_act) - else: - self.activation = config.hidden_act - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - - def build(self, input_shape=None): - self.decoder = self.add_weight( - name="decoder/weight", - shape=[self.config.vocab_size, self.output_embedding_size], - initializer=get_initializer(self.initializer_range), - ) - self.decoder_bias = self.add_weight( - shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias" - ) - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, self.config.output_embedding_size]) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self - - def set_output_embeddings(self, value): - self.decoder = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"decoder_bias": self.decoder_bias} - - def set_bias(self, value: tf.Variable): - self.decoder_bias = value["decoder_bias"] - self.config.vocab_size = shape_list(value["decoder_bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.activation(hidden_states) - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.output_embedding_size]) - hidden_states = self.LayerNorm(hidden_states) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias) - return hidden_states - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RemBert -class TFRemBertMLMHead(keras.layers.Layer): - def __init__(self, config: RemBertConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFRemBertLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@keras_serializable -class TFRemBertMainLayer(keras.layers.Layer): - config_class = RemBertConfig - - def __init__(self, config: RemBertConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.embeddings = TFRemBertEmbeddings(config, name="embeddings") - self.encoder = TFRemBertEncoder(config, name="encoder") - self.pooler = TFRemBertPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFRemBertPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RemBertConfig - base_model_prefix = "rembert" - - -REMBERT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`RemBertConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -REMBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.", - REMBERT_START_DOCSTRING, -) -class TFRemBertModel(TFRemBertPreTrainedModel): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.rembert = TFRemBertMainLayer(config, name="rembert") - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - - -@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING) -class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFRemBertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) - self.mlm = TFRemBertMLMHead(config, input_embeddings=self.rembert.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING -) -class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFRemBertForCausalLM` as a standalone, add `is_decoder=True.`") - - self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) - self.mlm = TFRemBertMLMHead(config, input_embeddings=self.rembert.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """ - RemBERT Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks. - """, - REMBERT_START_DOCSTRING, -) -class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.rembert = TFRemBertMainLayer(config, name="rembert") - self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - REMBERT_START_DOCSTRING, -) -class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.rembert = TFRemBertMainLayer(config, name="rembert") - self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) - self.classifier = keras.layers.Dense( - units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None - flat_attention_mask = ( - tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None - ) - flat_token_type_ids = ( - tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None - ) - flat_position_ids = ( - tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None - ) - flat_inputs_embeds = ( - tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.rembert( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - position_ids=flat_position_ids, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - REMBERT_START_DOCSTRING, -) -class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - REMBERT_START_DOCSTRING, -) -class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config: RemBertConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.rembert = TFRemBertMainLayer(config, add_pooling_layer=False, name="rembert") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="google/rembert", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.rembert( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rembert", None) is not None: - with tf.name_scope(self.rembert.name): - self.rembert.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFRemBertForCausalLM", - "TFRemBertForMaskedLM", - "TFRemBertForMultipleChoice", - "TFRemBertForQuestionAnswering", - "TFRemBertForSequenceClassification", - "TFRemBertForTokenClassification", - "TFRemBertLayer", - "TFRemBertModel", - "TFRemBertPreTrainedModel", -] diff --git a/src/transformers/models/resnet/modeling_flax_resnet.py b/src/transformers/models/resnet/modeling_flax_resnet.py deleted file mode 100644 index a2a9418b7cf2..000000000000 --- a/src/transformers/models/resnet/modeling_flax_resnet.py +++ /dev/null @@ -1,704 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithNoAttention, - FlaxBaseModelOutputWithPoolingAndNoAttention, - FlaxImageClassifierOutputWithNoAttention, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward -from .configuration_resnet import ResNetConfig - - -RESNET_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -RESNET_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`AutoImageProcessor.__call__`] for details. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class Identity(nn.Module): - """Identity function.""" - - @nn.compact - def __call__(self, x, **kwargs): - return x - - -class FlaxResNetConvLayer(nn.Module): - out_channels: int - kernel_size: int = 3 - stride: int = 1 - activation: Optional[str] = "relu" - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.convolution = nn.Conv( - self.out_channels, - kernel_size=(self.kernel_size, self.kernel_size), - strides=self.stride, - padding=self.kernel_size // 2, - dtype=self.dtype, - use_bias=False, - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype), - ) - self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) - self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.convolution(x) - hidden_state = self.normalization(hidden_state, use_running_average=deterministic) - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxResNetEmbeddings(nn.Module): - """ - ResNet Embeddings (stem) composed of a single aggressive convolution. - """ - - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.embedder = FlaxResNetConvLayer( - self.config.embedding_size, - kernel_size=7, - stride=2, - activation=self.config.hidden_act, - dtype=self.dtype, - ) - - self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) - - def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - num_channels = pixel_values.shape[-1] - if num_channels != self.config.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - embedding = self.embedder(pixel_values, deterministic=deterministic) - embedding = self.max_pool(embedding) - return embedding - - -class FlaxResNetShortCut(nn.Module): - """ - ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to - downsample the input using `stride=2`. - """ - - out_channels: int - stride: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.convolution = nn.Conv( - self.out_channels, - kernel_size=(1, 1), - strides=self.stride, - use_bias=False, - kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), - dtype=self.dtype, - ) - self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = self.convolution(x) - hidden_state = self.normalization(hidden_state, use_running_average=deterministic) - return hidden_state - - -class FlaxResNetBasicLayerCollection(nn.Module): - out_channels: int - stride: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layer = [ - FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype), - FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype), - ] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - for layer in self.layer: - hidden_state = layer(hidden_state, deterministic=deterministic) - return hidden_state - - -class FlaxResNetBasicLayer(nn.Module): - """ - A classic ResNet's residual layer composed by two `3x3` convolutions. - """ - - in_channels: int - out_channels: int - stride: int = 1 - activation: Optional[str] = "relu" - dtype: jnp.dtype = jnp.float32 - - def setup(self): - should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 - self.shortcut = ( - FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) - if should_apply_shortcut - else None - ) - self.layer = FlaxResNetBasicLayerCollection( - out_channels=self.out_channels, - stride=self.stride, - dtype=self.dtype, - ) - self.activation_func = ACT2FN[self.activation] - - def __call__(self, hidden_state, deterministic: bool = True): - residual = hidden_state - hidden_state = self.layer(hidden_state, deterministic=deterministic) - - if self.shortcut is not None: - residual = self.shortcut(residual, deterministic=deterministic) - hidden_state += residual - - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxResNetBottleNeckLayerCollection(nn.Module): - out_channels: int - stride: int = 1 - activation: Optional[str] = "relu" - reduction: int = 4 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - reduces_channels = self.out_channels // self.reduction - - self.layer = [ - FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"), - FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"), - FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"), - ] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - for layer in self.layer: - hidden_state = layer(hidden_state, deterministic=deterministic) - return hidden_state - - -class FlaxResNetBottleNeckLayer(nn.Module): - """ - A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the - input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution - remaps the reduced features to `out_channels`. - """ - - in_channels: int - out_channels: int - stride: int = 1 - activation: Optional[str] = "relu" - reduction: int = 4 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 - self.shortcut = ( - FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) - if should_apply_shortcut - else None - ) - - self.layer = FlaxResNetBottleNeckLayerCollection( - self.out_channels, - stride=self.stride, - activation=self.activation, - reduction=self.reduction, - dtype=self.dtype, - ) - - self.activation_func = ACT2FN[self.activation] - - def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - residual = hidden_state - - if self.shortcut is not None: - residual = self.shortcut(residual, deterministic=deterministic) - hidden_state = self.layer(hidden_state, deterministic) - hidden_state += residual - hidden_state = self.activation_func(hidden_state) - return hidden_state - - -class FlaxResNetStageLayersCollection(nn.Module): - """ - A ResNet stage composed by stacked layers. - """ - - config: ResNetConfig - in_channels: int - out_channels: int - stride: int = 2 - depth: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer - - layers = [ - # downsampling is done in the first layer with stride of 2 - layer( - self.in_channels, - self.out_channels, - stride=self.stride, - activation=self.config.hidden_act, - dtype=self.dtype, - name="0", - ), - ] - - for i in range(self.depth - 1): - layers.append( - layer( - self.out_channels, - self.out_channels, - activation=self.config.hidden_act, - dtype=self.dtype, - name=str(i + 1), - ) - ) - - self.layers = layers - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - hidden_state = x - for layer in self.layers: - hidden_state = layer(hidden_state, deterministic=deterministic) - return hidden_state - - -class FlaxResNetStage(nn.Module): - """ - A ResNet stage composed by stacked layers. - """ - - config: ResNetConfig - in_channels: int - out_channels: int - stride: int = 2 - depth: int = 2 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = FlaxResNetStageLayersCollection( - self.config, - in_channels=self.in_channels, - out_channels=self.out_channels, - stride=self.stride, - depth=self.depth, - dtype=self.dtype, - ) - - def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - return self.layers(x, deterministic=deterministic) - - -class FlaxResNetStageCollection(nn.Module): - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) - stages = [ - FlaxResNetStage( - self.config, - self.config.embedding_size, - self.config.hidden_sizes[0], - stride=2 if self.config.downsample_in_first_stage else 1, - depth=self.config.depths[0], - dtype=self.dtype, - name="0", - ) - ] - - for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): - stages.append( - FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) - ) - - self.stages = stages - - def __call__( - self, - hidden_state: jnp.ndarray, - output_hidden_states: bool = False, - deterministic: bool = True, - ) -> FlaxBaseModelOutputWithNoAttention: - hidden_states = () if output_hidden_states else None - - for stage_module in self.stages: - if output_hidden_states: - hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) - - hidden_state = stage_module(hidden_state, deterministic=deterministic) - - return hidden_state, hidden_states - - -class FlaxResNetEncoder(nn.Module): - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_state: jnp.ndarray, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> FlaxBaseModelOutputWithNoAttention: - hidden_state, hidden_states = self.stages( - hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic - ) - - if output_hidden_states: - hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states] if v is not None) - - return FlaxBaseModelOutputWithNoAttention( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - ) - - -class FlaxResNetPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ResNetConfig - base_model_prefix = "resnet" - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: ResNetConfig, - input_shape=(1, 224, 224, 3), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, config.num_channels) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - pixel_values = jnp.zeros(input_shape, dtype=self.dtype) - - rngs = {"params": rng} - - random_params = self.module.init(rngs, pixel_values, return_dict=False) - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) - def __call__( - self, - pixel_values, - params: Optional[dict] = None, - train: bool = False, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - - return self.module.apply( - { - "params": params["params"] if params is not None else self.params["params"], - "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], - }, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - output_hidden_states, - return_dict, - rngs=rngs, - mutable=["batch_stats"] if train else False, # Returning tuple with batch_stats only when train is True - ) - - -class FlaxResNetModule(nn.Module): - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype) - - # Adaptive average pooling used in resnet - self.pooler = partial( - nn.avg_pool, - padding=((0, 0), (0, 0)), - ) - - def __call__( - self, - pixel_values, - deterministic: bool = True, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - embedding_output = self.embedder(pixel_values, deterministic=deterministic) - - encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - last_hidden_state = encoder_outputs[0] - - pooled_output = self.pooler( - last_hidden_state, - window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), - strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), - ).transpose(0, 3, 1, 2) - - last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - ) - - -@add_start_docstrings( - "The bare ResNet model outputting raw features without any specific head on top.", - RESNET_START_DOCSTRING, -) -class FlaxResNetModel(FlaxResNetPreTrainedModel): - module_class = FlaxResNetModule - - -FLAX_VISION_MODEL_DOCSTRING = """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, FlaxResNetModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") - >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50") - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING) -append_replace_return_docstrings( - FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig -) - - -class FlaxResNetClassifierCollection(nn.Module): - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - return self.classifier(x) - - -class FlaxResNetForImageClassificationModule(nn.Module): - config: ResNetConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype) - - if self.config.num_labels > 0: - self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype) - else: - self.classifier = Identity() - - def __call__( - self, - pixel_values=None, - deterministic: bool = True, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.resnet( - pixel_values, - deterministic=deterministic, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(pooled_output[:, :, 0, 0]) - - if not return_dict: - output = (logits,) + outputs[2:] - return output - - return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) - - -@add_start_docstrings( - """ - ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - RESNET_START_DOCSTRING, -) -class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel): - module_class = FlaxResNetForImageClassificationModule - - -FLAX_VISION_CLASSIF_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification - >>> from PIL import Image - >>> import jax - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") - >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) - >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) - ``` -""" - -overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) -append_replace_return_docstrings( - FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig -) - - -__all__ = ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] diff --git a/src/transformers/models/resnet/modeling_tf_resnet.py b/src/transformers/models/resnet/modeling_tf_resnet.py deleted file mode 100644 index f7c415f97b05..000000000000 --- a/src/transformers/models/resnet/modeling_tf_resnet.py +++ /dev/null @@ -1,596 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow ResNet model.""" - -from typing import Optional, Union - -import tensorflow as tf - -from ...activations_tf import ACT2FN -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithNoAttention, - TFBaseModelOutputWithPoolingAndNoAttention, - TFImageClassifierOutputWithNoAttention, -) -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_resnet import ResNetConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "ResNetConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" -_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" - - -class TFResNetConvLayer(keras.layers.Layer): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - activation: str = "relu", - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.pad_value = kernel_size // 2 - self.conv = keras.layers.Conv2D( - out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution" - ) - # Use same default momentum and epsilon as PyTorch equivalent - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") - self.activation = ACT2FN[activation] if activation is not None else keras.layers.Activation("linear") - self.in_channels = in_channels - self.out_channels = out_channels - - def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: - # Pad to match that done in the PyTorch Conv2D model - height_pad = width_pad = (self.pad_value, self.pad_value) - hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) - hidden_state = self.conv(hidden_state) - return hidden_state - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = self.convolution(hidden_state) - hidden_state = self.normalization(hidden_state, training=training) - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, None, self.in_channels]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.out_channels]) - - -class TFResNetEmbeddings(keras.layers.Layer): - """ - ResNet Embeddings (stem) composed of a single aggressive convolution. - """ - - def __init__(self, config: ResNetConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.embedder = TFResNetConvLayer( - config.num_channels, - config.embedding_size, - kernel_size=7, - stride=2, - activation=config.hidden_act, - name="embedder", - ) - self.pooler = keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler") - self.num_channels = config.num_channels - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: - _, _, _, num_channels = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - hidden_state = pixel_values - hidden_state = self.embedder(hidden_state) - hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]]) - hidden_state = self.pooler(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedder", None) is not None: - with tf.name_scope(self.embedder.name): - self.embedder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFResNetShortCut(keras.layers.Layer): - """ - ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to - downsample the input using `stride=2`. - """ - - def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs) -> None: - super().__init__(**kwargs) - self.convolution = keras.layers.Conv2D( - out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" - ) - # Use same default momentum and epsilon as PyTorch equivalent - self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") - self.in_channels = in_channels - self.out_channels = out_channels - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_state = x - hidden_state = self.convolution(hidden_state) - hidden_state = self.normalization(hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "convolution", None) is not None: - with tf.name_scope(self.convolution.name): - self.convolution.build([None, None, None, self.in_channels]) - if getattr(self, "normalization", None) is not None: - with tf.name_scope(self.normalization.name): - self.normalization.build([None, None, None, self.out_channels]) - - -class TFResNetBasicLayer(keras.layers.Layer): - """ - A classic ResNet's residual layer composed by two `3x3` convolutions. - """ - - def __init__( - self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs - ) -> None: - super().__init__(**kwargs) - should_apply_shortcut = in_channels != out_channels or stride != 1 - self.conv1 = TFResNetConvLayer(in_channels, out_channels, stride=stride, name="layer.0") - self.conv2 = TFResNetConvLayer(out_channels, out_channels, activation=None, name="layer.1") - self.shortcut = ( - TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") - if should_apply_shortcut - else keras.layers.Activation("linear", name="shortcut") - ) - self.activation = ACT2FN[activation] - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - residual = hidden_state - hidden_state = self.conv1(hidden_state, training=training) - hidden_state = self.conv2(hidden_state, training=training) - residual = self.shortcut(residual, training=training) - hidden_state += residual - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv1", None) is not None: - with tf.name_scope(self.conv1.name): - self.conv1.build(None) - if getattr(self, "conv2", None) is not None: - with tf.name_scope(self.conv2.name): - self.conv2.build(None) - if getattr(self, "shortcut", None) is not None: - with tf.name_scope(self.shortcut.name): - self.shortcut.build(None) - - -class TFResNetBottleNeckLayer(keras.layers.Layer): - """ - A classic ResNet's bottleneck layer composed by three `3x3` convolutions. - - The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` - convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - stride: int = 1, - activation: str = "relu", - reduction: int = 4, - **kwargs, - ) -> None: - super().__init__(**kwargs) - should_apply_shortcut = in_channels != out_channels or stride != 1 - reduces_channels = out_channels // reduction - self.conv0 = TFResNetConvLayer(in_channels, reduces_channels, kernel_size=1, name="layer.0") - self.conv1 = TFResNetConvLayer(reduces_channels, reduces_channels, stride=stride, name="layer.1") - self.conv2 = TFResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None, name="layer.2") - self.shortcut = ( - TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") - if should_apply_shortcut - else keras.layers.Activation("linear", name="shortcut") - ) - self.activation = ACT2FN[activation] - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - residual = hidden_state - hidden_state = self.conv0(hidden_state, training=training) - hidden_state = self.conv1(hidden_state, training=training) - hidden_state = self.conv2(hidden_state, training=training) - residual = self.shortcut(residual, training=training) - hidden_state += residual - hidden_state = self.activation(hidden_state) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv0", None) is not None: - with tf.name_scope(self.conv0.name): - self.conv0.build(None) - if getattr(self, "conv1", None) is not None: - with tf.name_scope(self.conv1.name): - self.conv1.build(None) - if getattr(self, "conv2", None) is not None: - with tf.name_scope(self.conv2.name): - self.conv2.build(None) - if getattr(self, "shortcut", None) is not None: - with tf.name_scope(self.shortcut.name): - self.shortcut.build(None) - - -class TFResNetStage(keras.layers.Layer): - """ - A ResNet stage composed of stacked layers. - """ - - def __init__( - self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs - ) -> None: - super().__init__(**kwargs) - - layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer - - layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")] - layers += [ - layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}") - for i in range(depth - 1) - ] - self.stage_layers = layers - - def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: - for layer in self.stage_layers: - hidden_state = layer(hidden_state, training=training) - return hidden_state - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "stage_layers", None) is not None: - for layer in self.stage_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFResNetEncoder(keras.layers.Layer): - def __init__(self, config: ResNetConfig, **kwargs) -> None: - super().__init__(**kwargs) - # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input - self.stages = [ - TFResNetStage( - config, - config.embedding_size, - config.hidden_sizes[0], - stride=2 if config.downsample_in_first_stage else 1, - depth=config.depths[0], - name="stages.0", - ) - ] - for i, (in_channels, out_channels, depth) in enumerate( - zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:]) - ): - self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) - - def call( - self, - hidden_state: tf.Tensor, - output_hidden_states: bool = False, - return_dict: bool = True, - training: bool = False, - ) -> TFBaseModelOutputWithNoAttention: - hidden_states = () if output_hidden_states else None - - for stage_module in self.stages: - if output_hidden_states: - hidden_states = hidden_states + (hidden_state,) - - hidden_state = stage_module(hidden_state, training=training) - - if output_hidden_states: - hidden_states = hidden_states + (hidden_state,) - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states] if v is not None) - - return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "stages", None) is not None: - for layer in self.stages: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFResNetPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ResNetConfig - base_model_prefix = "resnet" - main_input_name = "pixel_values" - - @property - def input_signature(self): - return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)} - - -RESNET_START_DOCSTRING = r""" - This model is a TensorFlow - [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a - regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -RESNET_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`ConvNextImageProcessor.__call__`] for details. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@keras_serializable -class TFResNetMainLayer(keras.layers.Layer): - config_class = ResNetConfig - - def __init__(self, config: ResNetConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - self.embedder = TFResNetEmbeddings(config, name="embedder") - self.encoder = TFResNetEncoder(config, name="encoder") - self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True) - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TF 2.0 image layers can't use NCHW format when running on CPU. - # We transpose to NHWC format and then transpose back after the full forward pass. - # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) - pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) - embedding_output = self.embedder(pixel_values, training=training) - - encoder_outputs = self.encoder( - embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - last_hidden_state = encoder_outputs[0] - - pooled_output = self.pooler(last_hidden_state) - - # Transpose all the outputs to the NCHW format - # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) - last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2)) - pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2)) - hidden_states = () - for hidden_state in encoder_outputs[1:]: - hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state) - - if not return_dict: - return (last_hidden_state, pooled_output) + hidden_states - - hidden_states = hidden_states if output_hidden_states else None - - return TFBaseModelOutputWithPoolingAndNoAttention( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embedder", None) is not None: - with tf.name_scope(self.embedder.name): - self.embedder.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -@add_start_docstrings( - "The bare ResNet model outputting raw features without any specific head on top.", - RESNET_START_DOCSTRING, -) -class TFResNetModel(TFResNetPreTrainedModel): - def __init__(self, config: ResNetConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - self.resnet = TFResNetMainLayer(config=config, name="resnet") - - @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndNoAttention, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - resnet_outputs = self.resnet( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return resnet_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "resnet", None) is not None: - with tf.name_scope(self.resnet.name): - self.resnet.build(None) - - -@add_start_docstrings( - """ - ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for - ImageNet. - """, - RESNET_START_DOCSTRING, -) -class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: ResNetConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - self.num_labels = config.num_labels - self.resnet = TFResNetMainLayer(config, name="resnet") - # classification head - self.classifier_layer = ( - keras.layers.Dense(config.num_labels, name="classifier.1") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="classifier.1") - ) - self.config = config - - def classifier(self, x: tf.Tensor) -> tf.Tensor: - x = keras.layers.Flatten()(x) - logits = self.classifier_layer(x) - return logits - - @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFImageClassifierOutputWithNoAttention, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - @unpack_inputs - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.resnet( - pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training - ) - - pooled_output = outputs.pooler_output if return_dict else outputs[1] - - logits = self.classifier(pooled_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return (loss,) + output if loss is not None else output - - return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "resnet", None) is not None: - with tf.name_scope(self.resnet.name): - self.resnet.build(None) - if getattr(self, "classifier_layer", None) is not None: - with tf.name_scope(self.classifier_layer.name): - self.classifier_layer.build([None, None, self.config.hidden_sizes[-1]]) - - -__all__ = ["TFResNetForImageClassification", "TFResNetModel", "TFResNetPreTrainedModel"] diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py deleted file mode 100644 index 3b46c0fa682f..000000000000 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ /dev/null @@ -1,1500 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_roberta import RobertaConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" -_CONFIG_FOR_DOC = "RobertaConfig" - -remat = nn_partitioning.remat - - -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: jnp.ndarray - padding_idx: int - - Returns: jnp.ndarray - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = (input_ids != padding_idx).astype("i4") - - if mask.ndim > 2: - mask = mask.reshape((-1, mask.shape[-1])) - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - incremental_indices = incremental_indices.reshape(input_ids.shape) - else: - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - - return incremental_indices.astype("i4") + padding_idx - - -ROBERTA_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`RobertaConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ROBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta -class FlaxRobertaEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta -class FlaxRobertaSelfAttention(nn.Module): - config: RobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta -class FlaxRobertaSelfOutput(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta -class FlaxRobertaAttention(nn.Module): - config: RobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta -class FlaxRobertaIntermediate(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta -class FlaxRobertaOutput(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta -class FlaxRobertaLayer(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) - self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) - self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta -class FlaxRobertaLayerCollection(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta -class FlaxRobertaEncoder(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxRobertaLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta -class FlaxRobertaPooler(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.tanh(cls_hidden_state) - - -class FlaxRobertaLMHead(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.decoder = nn.Dense( - self.config.vocab_size, - dtype=self.dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN["gelu"](hidden_states) - hidden_states = self.layer_norm(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -class FlaxRobertaClassificationHead(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.out_proj = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = nn.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RobertaConfig - base_model_prefix = "roberta" - - module_class: nn.Module = None - - def __init__( - self, - config: RobertaConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.ones_like(input_ids) - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxRobertaAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta -class FlaxRobertaModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRobertaEncoder( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaModel(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaModule - - -append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) - - -class FlaxRobertaForMaskedLMModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) -class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForMaskedLMModule - - -append_call_sample_docstring( - FlaxRobertaForMaskedLM, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPooling, - _CONFIG_FOR_DOC, - mask="", -) - - -class FlaxRobertaForSequenceClassificationModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxRobertaForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta -class FlaxRobertaForMultipleChoiceModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxRobertaForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta -class FlaxRobertaForTokenClassificationModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForTokenClassificationModule - - -append_call_sample_docstring( - FlaxRobertaForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta -class FlaxRobertaForQuestionAnsweringModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxRobertaForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxRobertaForCausalLMModule(nn.Module): - config: RobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for - autoregressive tasks. - """, - ROBERTA_START_DOCSTRING, -) -class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): - module_class = FlaxRobertaForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxRobertaForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxRobertaForCausalLM", - "FlaxRobertaForMaskedLM", - "FlaxRobertaForMultipleChoice", - "FlaxRobertaForQuestionAnswering", - "FlaxRobertaForSequenceClassification", - "FlaxRobertaForTokenClassification", - "FlaxRobertaModel", - "FlaxRobertaPreTrainedModel", -] diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py deleted file mode 100644 index c5c56b85d5f3..000000000000 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ /dev/null @@ -1,1782 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 RoBERTa model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_roberta import RobertaConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" -_CONFIG_FOR_DOC = "RobertaConfig" - - -class TFRobertaEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - - return incremental_indices + self.padding_idx - - def call( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - inputs_embeds=None, - past_key_values_length=0, - training=False, - ): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids( - input_ids=input_ids, past_key_values_length=past_key_values_length - ) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Roberta -class TFRobertaPooler(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Roberta -class TFRobertaSelfAttention(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFRobertaModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Roberta -class TFRobertaSelfOutput(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta -class TFRobertaAttention(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFRobertaSelfAttention(config, name="self") - self.dense_output = TFRobertaSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Roberta -class TFRobertaIntermediate(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Roberta -class TFRobertaOutput(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta -class TFRobertaLayer(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFRobertaAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFRobertaAttention(config, name="crossattention") - self.intermediate = TFRobertaIntermediate(config, name="intermediate") - self.bert_output = TFRobertaOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta -class TFRobertaEncoder(keras.layers.Layer): - def __init__(self, config: RobertaConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFRobertaMainLayer(keras.layers.Layer): - config_class = RobertaConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.encoder = TFRobertaEncoder(config, name="encoder") - self.pooler = TFRobertaPooler(config, name="pooler") if add_pooling_layer else None - # The embeddings must be the last declaration in order to follow the weights order - self.embeddings = TFRobertaEmbeddings(config, name="embeddings") - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - - -class TFRobertaPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RobertaConfig - base_model_prefix = "roberta" - - -ROBERTA_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`RobertaConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ROBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_START_DOCSTRING, -) -class TFRobertaModel(TFRobertaPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.roberta = TFRobertaMainLayer(config, name="roberta") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - - -class TFRobertaLMHead(keras.layers.Layer): - """Roberta Head for masked language modeling.""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) -class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - expected_output="' Paris'", - expected_loss=0.1, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config: RobertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") - - self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -class TFRobertaClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ROBERTA_START_DOCSTRING, -) -class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.classifier = TFRobertaClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="cardiffnlp/twitter-roberta-base-emotion", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'optimism'", - expected_loss=0.08, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ROBERTA_START_DOCSTRING, -) -class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFRobertaMainLayer(config, name="roberta") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - outputs = self.roberta( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ROBERTA_START_DOCSTRING, -) -class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-large-ner-english", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", - expected_loss=0.01, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROBERTA_START_DOCSTRING, -) -class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-base-squad2", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="' puppet'", - expected_loss=0.86, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFRobertaForCausalLM", - "TFRobertaForMaskedLM", - "TFRobertaForMultipleChoice", - "TFRobertaForQuestionAnswering", - "TFRobertaForSequenceClassification", - "TFRobertaForTokenClassification", - "TFRobertaMainLayer", - "TFRobertaModel", - "TFRobertaPreTrainedModel", -] diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py deleted file mode 100644 index f65dc07bb165..000000000000 --- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +++ /dev/null @@ -1,1527 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax RoBERTa-PreLayerNorm model.""" - -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40" -_CONFIG_FOR_DOC = "RobertaPreLayerNormConfig" - -remat = nn_partitioning.remat - - -# Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: jnp.ndarray - padding_idx: int - - Returns: jnp.ndarray - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = (input_ids != padding_idx).astype("i4") - - if mask.ndim > 2: - mask = mask.reshape((-1, mask.shape[-1])) - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - incremental_indices = incremental_indices.reshape(input_ids.shape) - else: - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - - return incremental_indices.astype("i4") + padding_idx - - -ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormSelfAttention(nn.Module): - config: RobertaPreLayerNormConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxRobertaPreLayerNormSelfOutput(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = hidden_states + input_tensor - return hidden_states - - -class FlaxRobertaPreLayerNormAttention(nn.Module): - config: RobertaPreLayerNormConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxRobertaPreLayerNormSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - self.output = FlaxRobertaPreLayerNormSelfOutput(self.config, dtype=self.dtype) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - hidden_states_pre_layer_norm = self.LayerNorm(hidden_states) - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states_pre_layer_norm, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -class FlaxRobertaPreLayerNormIntermediate(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class FlaxRobertaPreLayerNormOutput(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = hidden_states + attention_output - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormLayer(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxRobertaPreLayerNormAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) - self.intermediate = FlaxRobertaPreLayerNormIntermediate(self.config, dtype=self.dtype) - self.output = FlaxRobertaPreLayerNormOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxRobertaPreLayerNormAttention(self.config, causal=False, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormLayerCollection(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxRobertaPreLayerNormCheckpointLayer = remat(FlaxRobertaPreLayerNormLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxRobertaPreLayerNormCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxRobertaPreLayerNormLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormEncoder(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxRobertaPreLayerNormLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormPooler(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.tanh(cls_hidden_state) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormLMHead(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.decoder = nn.Dense( - self.config.vocab_size, - dtype=self.dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN["gelu"](hidden_states) - hidden_states = self.layer_norm(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormClassificationHead(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.out_proj = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = nn.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class FlaxRobertaPreLayerNormPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RobertaPreLayerNormConfig - base_model_prefix = "roberta_prelayernorm" - - module_class: nn.Module = None - - def __init__( - self, - config: RobertaPreLayerNormConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.ones_like(input_ids) - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxRobertaPreLayerNormAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -class FlaxRobertaPreLayerNormModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxRobertaPreLayerNormEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRobertaPreLayerNormEncoder( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.pooler = FlaxRobertaPreLayerNormPooler(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - hidden_states = self.LayerNorm(hidden_states) - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaModel with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormModel(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormModule - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormModel, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPooling, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class FlaxRobertaPreLayerNormForMaskedLMModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][ - "embedding" - ] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLM with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForMaskedLM(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForMaskedLMModule - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormForMaskedLM, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPooling, - _CONFIG_FOR_DOC, - mask="", -) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class FlaxRobertaPreLayerNormForSequenceClassificationModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.classifier = FlaxRobertaPreLayerNormClassificationHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top - of the pooled output) e.g. for GLUE tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassification with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForSequenceClassification(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm -class FlaxRobertaPreLayerNormForMultipleChoiceModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled - output and a softmax) e.g. for RocStories/SWAG tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMultipleChoice with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForMultipleChoice(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxRobertaPreLayerNormForMultipleChoice, - ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"), -) -append_call_sample_docstring( - FlaxRobertaPreLayerNormForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm -class FlaxRobertaPreLayerNormForTokenClassificationModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForTokenClassification with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForTokenClassification(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForTokenClassificationModule - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm -class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model with a span classification head on top for extractive question-answering tasks like SQuAD - (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForQuestionAnswering with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForQuestionAnswering(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module): - config: RobertaPreLayerNormConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][ - "embedding" - ] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model with a language modeling head on top (a linear layer on top of the hidden-states output) - e.g for autoregressive tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->RobertaPreLayerNorm -class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel): - module_class = FlaxRobertaPreLayerNormForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxRobertaPreLayerNormForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxRobertaPreLayerNormForCausalLM", - "FlaxRobertaPreLayerNormForMaskedLM", - "FlaxRobertaPreLayerNormForMultipleChoice", - "FlaxRobertaPreLayerNormForQuestionAnswering", - "FlaxRobertaPreLayerNormForSequenceClassification", - "FlaxRobertaPreLayerNormForTokenClassification", - "FlaxRobertaPreLayerNormModel", - "FlaxRobertaPreLayerNormPreTrainedModel", -] diff --git a/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py deleted file mode 100644 index 0a370f390269..000000000000 --- a/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py +++ /dev/null @@ -1,1807 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 RoBERTa-PreLayerNorm model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40" -_CONFIG_FOR_DOC = "RobertaPreLayerNormConfig" - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings with Roberta->RobertaPreLayerNorm -class TFRobertaPreLayerNormEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - - return incremental_indices + self.padding_idx - - def call( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - inputs_embeds=None, - past_key_values_length=0, - training=False, - ): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids( - input_ids=input_ids, past_key_values_length=past_key_values_length - ) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->RobertaPreLayerNorm -class TFRobertaPreLayerNormPooler(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->RobertaPreLayerNorm -class TFRobertaPreLayerNormSelfAttention(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFRobertaPreLayerNormModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -class TFRobertaPreLayerNormSelfOutput(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = hidden_states + input_tensor - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFRobertaPreLayerNormAttention(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFRobertaPreLayerNormSelfAttention(config, name="self") - self.dense_output = TFRobertaPreLayerNormSelfOutput(config, name="output") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention.prune_heads - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - hidden_states_pre_layer_norm = self.LayerNorm(inputs=input_tensor) - self_outputs = self.self_attention( - hidden_states=hidden_states_pre_layer_norm, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFRobertaPreLayerNormIntermediate(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.LayerNorm(inputs=hidden_states) - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFRobertaPreLayerNormOutput(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = hidden_states + input_tensor - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RobertaPreLayerNorm -class TFRobertaPreLayerNormLayer(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFRobertaPreLayerNormAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFRobertaPreLayerNormAttention(config, name="crossattention") - self.intermediate = TFRobertaPreLayerNormIntermediate(config, name="intermediate") - self.bert_output = TFRobertaPreLayerNormOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->RobertaPreLayerNorm -class TFRobertaPreLayerNormEncoder(keras.layers.Layer): - def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFRobertaPreLayerNormLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFRobertaPreLayerNormMainLayer(keras.layers.Layer): - config_class = RobertaPreLayerNormConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.encoder = TFRobertaPreLayerNormEncoder(config, name="encoder") - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.pooler = TFRobertaPreLayerNormPooler(config, name="pooler") if add_pooling_layer else None - # The embeddings must be the last declaration in order to follow the weights order - self.embeddings = TFRobertaPreLayerNormEmbeddings(config, name="embeddings") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.LayerNorm(inputs=sequence_output) - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class TFRobertaPreLayerNormPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RobertaPreLayerNormConfig - base_model_prefix = "roberta_prelayernorm" - - -ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.", - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class TFRobertaPreLayerNormModel(TFRobertaPreLayerNormPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer(config, name="roberta_prelayernorm") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.roberta_prelayernorm( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->RobertaPreLayerNorm -class TFRobertaPreLayerNormLMHead(keras.layers.Layer): - """RobertaPreLayerNorm Head for masked language modeling.""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings( - """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING -) -class TFRobertaPreLayerNormForMaskedLM(TFRobertaPreLayerNormPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( - config, add_pooling_layer=False, name="roberta_prelayernorm" - ) - self.lm_head = TFRobertaPreLayerNormLMHead(config, self.roberta_prelayernorm.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - expected_output="' Paris'", - expected_loss=0.69, - ) - # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM.call with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFCausalLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config: RobertaPreLayerNormConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning( - "If you want to use `TFRobertaPreLayerNormLMHeadModel` as a standalone, add `is_decoder=True.`" - ) - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( - config, add_pooling_layer=False, name="roberta_prelayernorm" - ) - self.lm_head = TFRobertaPreLayerNormLMHead( - config, input_embeddings=self.roberta_prelayernorm.embeddings, name="lm_head" - ) - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.roberta_prelayernorm( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm -class TFRobertaPreLayerNormClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa-PreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top - of the pooled output) e.g. for GLUE tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -class TFRobertaPreLayerNormForSequenceClassification( - TFRobertaPreLayerNormPreTrainedModel, TFSequenceClassificationLoss -): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( - config, add_pooling_layer=False, name="roberta_prelayernorm" - ) - self.classifier = TFRobertaPreLayerNormClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification.call with roberta->roberta_prelayernorm - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled - output and a softmax) e.g. for RocStories/SWAG tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm -class TFRobertaPreLayerNormForMultipleChoice(TFRobertaPreLayerNormPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer(config, name="roberta_prelayernorm") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - outputs = self.roberta_prelayernorm( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa-PreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -class TFRobertaPreLayerNormForTokenClassification(TFRobertaPreLayerNormPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( - config, add_pooling_layer=False, name="roberta_prelayernorm" - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification.call with roberta->roberta_prelayernorm - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoBERTa-PreLayerNorm Model with a span classification head on top for extractive question-answering tasks like - SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROBERTA_PRELAYERNORM_START_DOCSTRING, -) -class TFRobertaPreLayerNormForQuestionAnswering(TFRobertaPreLayerNormPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( - config, add_pooling_layer=False, name="roberta_prelayernorm" - ) - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering.call with roberta->roberta_prelayernorm - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.roberta_prelayernorm( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta_prelayernorm", None) is not None: - with tf.name_scope(self.roberta_prelayernorm.name): - self.roberta_prelayernorm.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFRobertaPreLayerNormForCausalLM", - "TFRobertaPreLayerNormForMaskedLM", - "TFRobertaPreLayerNormForMultipleChoice", - "TFRobertaPreLayerNormForQuestionAnswering", - "TFRobertaPreLayerNormForSequenceClassification", - "TFRobertaPreLayerNormForTokenClassification", - "TFRobertaPreLayerNormMainLayer", - "TFRobertaPreLayerNormModel", - "TFRobertaPreLayerNormPreTrainedModel", -] diff --git a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index d227948e0ee3..000000000000 --- a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,62 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert RoFormer checkpoint.""" - -import argparse - -import torch - -from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): - # Initialise PyTorch model - config = RoFormerConfig.from_json_file(bert_config_file) - print(f"Building PyTorch model from configuration: {config}") - model = RoFormerForMaskedLM(config) - - # Load weights from tf checkpoint - load_tf_weights_in_roformer(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--bert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py deleted file mode 100644 index de78eb4787c0..000000000000 --- a/src/transformers/models/roformer/modeling_flax_roformer.py +++ /dev/null @@ -1,1091 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax RoFormer model.""" - -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_roformer import RoFormerConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" -_CONFIG_FOR_DOC = "RoFormerConfig" - - -ROFORMER_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`RoFormerConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -ROFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions -def create_sinusoidal_positions(n_pos, dim): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - sentinel = dim // 2 + dim % 2 - out = np.zeros_like(position_enc) - out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) - out[:, sentinel:] = np.cos(position_enc[:, 1::2]) - - return jnp.array(out) - - -class FlaxRoFormerEmbeddings(nn.Module): - """Construct the embeddings from word and token_type embeddings.""" - - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxRoFormerSelfAttention(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.rotary_value = self.config.rotary_value - - def __call__( - self, - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - deterministic=True, - output_attentions: bool = False, - ): - head_dim = self.config.hidden_size // self.config.num_attention_heads - - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - - if sinusoidal_pos is not None: - if self.rotary_value: - query_states, key_states, value_states = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_states, key_states, value_states - ) - else: - query_states, key_states = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_states, key_states - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - @staticmethod - def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): - sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) - sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape) - cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape) - - def rotate_layer(layer, sin_pos, cos_pos): - rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape) - rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos) - rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos) - return rotary_matrix_cos + rotary_matrix_sin - - query_layer = rotate_layer(query_layer, sin_pos, cos_pos) - key_layer = rotate_layer(key_layer, sin_pos, cos_pos) - if value_layer is not None: - value_layer = rotate_layer(value_layer, sin_pos, cos_pos) - return query_layer, key_layer, value_layer - return query_layer, key_layer - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->RoFormer -class FlaxRoFormerSelfOutput(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class FlaxRoFormerAttention(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxRoFormerSelfAttention(self.config, dtype=self.dtype) - self.output = FlaxRoFormerSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask=layer_head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->RoFormer -class FlaxRoFormerIntermediate(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->RoFormer -class FlaxRoFormerOutput(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -class FlaxRoFormerLayer(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxRoFormerAttention(self.config, dtype=self.dtype) - self.intermediate = FlaxRoFormerIntermediate(self.config, dtype=self.dtype) - self.output = FlaxRoFormerOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - sinusiodal_pos, - layer_head_mask, - deterministic: bool = True, - output_attentions: bool = False, - ): - attention_outputs = self.attention( - hidden_states, - attention_mask, - sinusiodal_pos, - layer_head_mask=layer_head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - return outputs - - -class FlaxRoFormerLayerCollection(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxRoFormerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - sinusoidal_pos, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask=head_mask[i] if head_mask is not None else None, - deterministic=deterministic, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states,) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxRoFormerEncoder(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embed_positions = create_sinusoidal_positions( - self.config.max_position_embeddings, self.config.hidden_size // self.config.num_attention_heads - ) - self.layer = FlaxRoFormerLayerCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - sinusoidal_pos = self.embed_positions[: hidden_states.shape[1], :] - - return self.layer( - hidden_states, - attention_mask, - sinusoidal_pos, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->RoFormer -class FlaxRoFormerPredictionHeadTransform(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.activation = ACT2FN[self.config.hidden_act] - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return self.LayerNorm(hidden_states) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->RoFormer -class FlaxRoFormerLMPredictionHead(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.transform = FlaxRoFormerPredictionHeadTransform(self.config, dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.transform(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->RoFormer -class FlaxRoFormerOnlyMLMHead(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.predictions = FlaxRoFormerLMPredictionHead(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) - return hidden_states - - -class FlaxRoFormerClassificationHead(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.out_proj = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states, deterministic=True): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RoFormerConfig - base_model_prefix = "roformer" - module_class: nn.Module = None - - def __init__( - self, - config: RoFormerConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.zeros_like(input_ids) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - head_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxRoFormerModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.embeddings = FlaxRoFormerEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRoFormerEncoder(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embeddings(input_ids, token_type_ids, attention_mask, deterministic=deterministic) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - - if not return_dict: - return (hidden_states,) + outputs[1:] - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.", - ROFORMER_START_DOCSTRING, -) -class FlaxRoFormerModel(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerModule - - -append_call_sample_docstring(FlaxRoFormerModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) - - -class FlaxRoFormerForMaskedLMModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) - self.cls = FlaxRoFormerOnlyMLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roformer( - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roformer.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.cls(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING) -class FlaxRoFormerForMaskedLM(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerForMaskedLMModule - - -append_call_sample_docstring( - FlaxRoFormerForMaskedLM, - _CHECKPOINT_FOR_DOC, - FlaxMaskedLMOutput, - _CONFIG_FOR_DOC, - mask="", -) - - -class FlaxRoFormerForSequenceClassificationModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) - self.classifier = FlaxRoFormerClassificationHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roformer( - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - ROFORMER_START_DOCSTRING, -) -class FlaxRoFormerForSequenceClassification(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxRoFormerForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxRoFormerForMultipleChoiceModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) - - # Model - outputs = self.roformer( - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Equivalent to sequence_summary call in the PyTorch implementation - hidden_states = outputs[0] - pooled_output = hidden_states[:, -1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ROFORMER_START_DOCSTRING, -) -class FlaxRoFormerForMultipleChoice(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxRoFormerForMultipleChoice, ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxRoFormerForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxRoFormerForTokenClassificationModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roformer( - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ROFORMER_START_DOCSTRING, -) -class FlaxRoFormerForTokenClassification(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerForTokenClassificationModule - - -append_call_sample_docstring( - FlaxRoFormerForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxRoFormerForQuestionAnsweringModule(nn.Module): - config: RoFormerConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roformer( - input_ids, - attention_mask, - token_type_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROFORMER_START_DOCSTRING, -) -class FlaxRoFormerForQuestionAnswering(FlaxRoFormerPreTrainedModel): - module_class = FlaxRoFormerForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxRoFormerForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxRoFormerForMaskedLM", - "FlaxRoFormerForMultipleChoice", - "FlaxRoFormerForQuestionAnswering", - "FlaxRoFormerForSequenceClassification", - "FlaxRoFormerForTokenClassification", - "FlaxRoFormerModel", - "FlaxRoFormerPreTrainedModel", -] diff --git a/src/transformers/models/roformer/modeling_tf_roformer.py b/src/transformers/models/roformer/modeling_tf_roformer.py deleted file mode 100644 index e07374e9fdf5..000000000000 --- a/src/transformers/models/roformer/modeling_tf_roformer.py +++ /dev/null @@ -1,1546 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 RoFormer model.""" - -from __future__ import annotations - -import math - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPooling, - TFCausalLMOutput, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_roformer import RoFormerConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" -_CONFIG_FOR_DOC = "RoFormerConfig" - - -class TFRoFormerSinusoidalPositionalEmbedding(keras.layers.Layer): - """This module produces sinusoidal positional embeddings of any length.""" - - def __init__(self, num_positions: int, embedding_dim: int, **kwargs): - super().__init__(**kwargs) - - if embedding_dim % 2 != 0: - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - - self.embedding_dim = embedding_dim - self.num_positions = num_positions - - def build(self, input_shape: tf.TensorShape): - """ - Build shared token embedding layer Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - - weight = self._init_weight(self.num_positions, self.embedding_dim) - - self.weight = self.add_weight( - name="embeddings", - shape=[self.num_positions, self.embedding_dim], - ) - weight = tf.cast(weight, dtype=self.weight.dtype) - - self.weight.assign(weight) - - super().build(input_shape) - - @staticmethod - def _init_weight(n_pos: int, dim: int): - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. [dim // 2:] - """ - position_enc = np.array( - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] - ) - table = np.zeros_like(position_enc) - # index 0 is all zero - table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) - table[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) - # convert to tensor - table = tf.convert_to_tensor(table) - tf.stop_gradient(table) - return table - - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): - """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] - - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return tf.gather(self.weight, positions) - - -class TFRoFormerEmbeddings(keras.layers.Layer): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.embedding_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -class TFRoFormerSelfAttention(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - self.rotary_value = config.rotary_value - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - sinusoidal_pos: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - if sinusoidal_pos is not None: - if self.rotary_value: - query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer, value_layer - ) - else: - query_layer, key_layer = self.apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFRoFormerModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - return outputs - - @staticmethod - def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): - # https://kexue.fm/archives/8265 - # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] - # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] - sin, cos = tf.split(sinusoidal_pos, num_or_size_splits=2, axis=-1) - # sin [θ0,θ1,θ2......θd/2-1]-> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - # cos [θ0,θ1,θ2......θd/2-1]-> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - sin_pos = tf.repeat(sin, 2, axis=-1) - cos_pos = tf.repeat(cos, 2, axis=-1) - # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query_layer = tf.stack([-query_layer[..., 1::2], query_layer[..., ::2]], axis=-1) - rotate_half_query_layer = tf.reshape(rotate_half_query_layer, shape_list(query_layer)) - query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos - # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key_layer = tf.stack([-key_layer[..., 1::2], key_layer[..., ::2]], axis=-1) - rotate_half_key_layer = tf.reshape(rotate_half_key_layer, shape_list(key_layer)) - key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos - if value_layer is not None: - # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] - rotate_half_value_layer = tf.stack([-value_layer[..., 1::2], value_layer[..., ::2]], axis=-1) - rotate_half_value_layer = tf.reshape(rotate_half_value_layer, shape_list(value_layer)) - value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos - return query_layer, key_layer, value_layer - return query_layer, key_layer - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RoFormer -class TFRoFormerSelfOutput(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFRoFormerAttention(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFRoFormerSelfAttention(config, name="self") - self.dense_output = TFRoFormerSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - sinusoidal_pos: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - sinusoidal_pos=sinusoidal_pos, - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RoFormer -class TFRoFormerIntermediate(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RoFormer -class TFRoFormerOutput(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -class TFRoFormerLayer(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFRoFormerAttention(config, name="attention") - self.intermediate = TFRoFormerIntermediate(config, name="intermediate") - self.roformer_output = TFRoFormerOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - sinusoidal_pos: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - sinusoidal_pos=sinusoidal_pos, - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.roformer_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "roformer_output", None) is not None: - with tf.name_scope(self.roformer_output.name): - self.roformer_output.build(None) - - -class TFRoFormerEncoder(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - self.embed_positions = TFRoFormerSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.hidden_size // config.num_attention_heads, - name="embed_positions", - ) - self.layer = [TFRoFormerLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] - sinusoidal_pos = self.embed_positions(shape_list(hidden_states)[:-1])[None, None, :, :] - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - sinusoidal_pos=sinusoidal_pos, - head_mask=head_mask[i], - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFRoFormerPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.embedding_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.embedding_size]) - - -class TFRoFormerLMPredictionHead(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.embedding_size = config.embedding_size - - self.transform = TFRoFormerPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RoFormer -class TFRoFormerMLMHead(keras.layers.Layer): - def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFRoFormerLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@keras_serializable -class TFRoFormerMainLayer(keras.layers.Layer): - config_class = RoFormerConfig - - def __init__(self, config: RoFormerConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFRoFormerEmbeddings(config, name="embeddings") - if config.embedding_size != config.hidden_size: - self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project") - - self.encoder = TFRoFormerEncoder(config, name="encoder") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - training=training, - ) - if hasattr(self, "embeddings_project"): - embedding_output = self.embeddings_project(embedding_output, training=training) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "embeddings_project", None) is not None: - with tf.name_scope(self.embeddings_project.name): - self.embeddings_project.build([None, None, self.config.embedding_size]) - - -class TFRoFormerPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = RoFormerConfig - base_model_prefix = "roformer" - - -ROFORMER_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ROFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.", - ROFORMER_START_DOCSTRING, -) -class TFRoFormerModel(TFRoFormerPreTrainedModel): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - - -@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING) -class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFRoFormerForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -@add_start_docstrings( - """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING -) -class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFRoFormerForCausalLM` as a standalone, add `is_decoder=True.`") - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.mlm.predictions - - @unpack_inputs - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "mlm", None) is not None: - with tf.name_scope(self.mlm.name): - self.mlm.build(None) - - -class TFRoFormerClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.out_proj = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - - if isinstance(config.hidden_act, str): - self.classifier_act_fn = get_tf_activation(config.hidden_act) - else: - self.classifier_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.classifier_act_fn(hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.out_proj(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoFormer Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks. - """, - ROFORMER_START_DOCSTRING, -) -class TFRoFormerForSequenceClassification(TFRoFormerPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.classifier = TFRoFormerClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.classifier(hidden_states=outputs[0], training=training) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - ROFORMER_START_DOCSTRING, -) -class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.sequence_summary = TFSequenceSummary(config, config.initializer_range, name="sequence_summary") - self.classifier = keras.layers.Dense( - units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None - flat_attention_mask = ( - tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None - ) - flat_token_type_ids = ( - tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None - ) - flat_inputs_embeds = ( - tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - outputs = self.roformer( - input_ids=flat_input_ids, - attention_mask=flat_attention_mask, - token_type_ids=flat_token_type_ids, - head_mask=head_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - logits = self.sequence_summary(inputs=outputs[0], training=training) - logits = self.classifier(inputs=logits) - reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[1:] - - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - ROFORMER_START_DOCSTRING, -) -class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=training) - logits = self.classifier(inputs=sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - ROFORMER_START_DOCSTRING, -) -class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config: RoFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - - self.roformer = TFRoFormerMainLayer(config, name="roformer") - self.qa_outputs = keras.layers.Dense( - units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.roformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.qa_outputs(inputs=sequence_output) - start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) - start_logits = tf.squeeze(input=start_logits, axis=-1) - end_logits = tf.squeeze(input=end_logits, axis=-1) - loss = None - - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions, "end_position": end_positions} - loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roformer", None) is not None: - with tf.name_scope(self.roformer.name): - self.roformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFRoFormerForCausalLM", - "TFRoFormerForMaskedLM", - "TFRoFormerForMultipleChoice", - "TFRoFormerForQuestionAnswering", - "TFRoFormerForSequenceClassification", - "TFRoFormerForTokenClassification", - "TFRoFormerLayer", - "TFRoFormerModel", - "TFRoFormerPreTrainedModel", -] diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py deleted file mode 100644 index ac81288fa182..000000000000 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ /dev/null @@ -1,1723 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a -discrepancy, the original file should be regarded as the 'reference' version. -""" - -from __future__ import annotations - -import collections -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import ACT2FN -from ...modeling_tf_outputs import TFBaseModelOutput -from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs -from ...tf_utils import flatten, functional_layernorm -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "SamConfig" -_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" - - -@dataclass -class TFSamVisionEncoderOutput(ModelOutput): - """ - Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection - layer to the pooler_output. - - Args: - image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - image_embeds: tf.Tensor | None = None - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFSamImageSegmentationOutput(ModelOutput): - """ - Base class for Segment-Anything model's output - - Args: - iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`): - The iou scores of the predicted masks. - pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): - The predicted low resolutions masks. Needs to be post-processed by the processor - vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for - the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. - vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - iou_scores: tf.Tensor | None = None - pred_masks: tf.Tensor | None = None - vision_hidden_states: tuple[tf.Tensor, ...] | None = None - vision_attentions: tuple[tf.Tensor, ...] | None = None - mask_decoder_attentions: tuple[tf.Tensor, ...] | None = None - - -class TFSamPatchEmbeddings(keras.layers.Layer): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = keras.layers.Conv2D( - hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" - ) - - def call(self, pixel_values): - batch_size, num_channels, height, width = shape_list(pixel_values) - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -class TFSamMLPBlock(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1") - self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2") - self.act = ACT2FN[config.hidden_act] - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.lin1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.lin2(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lin1", None) is not None: - with tf.name_scope(self.lin1.name): - self.lin1.build([None, None, self.config.hidden_size]) - if getattr(self, "lin2", None) is not None: - with tf.name_scope(self.lin2.name): - self.lin2.build([None, None, self.config.mlp_dim]) - - -class TFSamLayerNorm(keras.layers.Layer): - r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, - width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). - """ - - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): - super().__init__(**kwargs) - self.eps = eps - self.data_format = data_format - self.normalized_shape = normalized_shape - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError(f"Unsupported data format: {self.data_format}") - - def build(self, input_shape): - self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") - self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") - super().build(input_shape) - - def call(self, x: tf.Tensor) -> tf.Tensor: - if self.data_format == "channels_last": - x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) - elif self.data_format == "channels_first": - x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) - return x - - -class TFSamAttention(keras.layers.Layer): - """ - SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and - values. - """ - - def __init__(self, config, downsample_rate=None, **kwargs): - super().__init__(**kwargs) - self.hidden_size = config.hidden_size - - downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate - - self.internal_dim = config.hidden_size // downsample_rate - self.num_attention_heads = config.num_attention_heads - if self.internal_dim % config.num_attention_heads != 0: - raise ValueError("num_attention_heads must divide hidden_size.") - - self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj") - self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj") - self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj") - self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj") - - def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: - batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) - c_per_head = channel // num_attention_heads - hidden_states = tf.reshape( - hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) - ) - return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) - - def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: - batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) - hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) - return tf.reshape( - hidden_states, - (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head), - ) - - def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: - # Input projections - query = self.q_proj(query) - key = self.k_proj(key) - value = self.v_proj(value) - - point_batch_size = shape_list(query)[1] - # Separate into heads - query = self._separate_heads(query, self.num_attention_heads) - key = self._separate_heads(key, self.num_attention_heads) - value = self._separate_heads(value, self.num_attention_heads) - - # SamAttention - _, _, _, c_per_head = shape_list(query) - attn = tf.matmul( - query, tf.transpose(key, perm=[0, 1, 3, 2]) - ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens - attn = attn / tf.math.sqrt(float(c_per_head)) - attn = tf.nn.softmax(attn, axis=-1) - - # Get output - out = tf.matmul(attn, value) - out = self._recombine_heads(out, point_batch_size) - out = self.out_proj(out) - - return out - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.hidden_size]) - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.hidden_size]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.internal_dim]) - - -class TFSamTwoWayAttentionBlock(keras.layers.Layer): - def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs): - """ - A transformer block with four layers: - (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on - sparse inputs (4) cross attention of dense inputs -> sparse inputs - - Arguments: - config (`SamMaskDecoderConfig`): - The configuration file used to instantiate the block - attention_downsample_rate (*optionalk*, int, defaults to 2): - The downsample ratio of the block used to reduce the inner dim of the attention. - skip_first_layer_pe (*optional*, bool, defaults to `False`): - Whether or not to skip the addition of the query_point_embedding on the first layer. - """ - super().__init__(**kwargs) - - self.hidden_size = config.hidden_size - self.layer_norm_eps = config.layer_norm_eps - - self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") - self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") - - self.cross_attn_token_to_image = TFSamAttention( - config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" - ) - self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") - - self.mlp = TFSamMLPBlock(config, name="mlp") - self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") - - self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") - self.cross_attn_image_to_token = TFSamAttention( - config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" - ) - - self.skip_first_layer_pe = skip_first_layer_pe - - def call( - self, - queries: tf.Tensor, - keys: tf.Tensor, - query_point_embedding: tf.Tensor, - key_point_embedding: tf.Tensor, - output_attentions: bool = False, - ): - # Self attention block - if self.skip_first_layer_pe: - queries = self.self_attn(query=queries, key=queries, value=queries) - else: - query = queries + query_point_embedding - attn_out = self.self_attn(query=query, key=query, value=queries) - queries = queries + attn_out - queries = self.layer_norm1(queries) - - # Cross attention block, tokens attending to image embedding - query = queries + query_point_embedding - key = keys + key_point_embedding - - attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) - queries = queries + attn_out - - queries = self.layer_norm2(queries) - - # MLP block - mlp_out = self.mlp(queries) - queries = queries + mlp_out - queries = self.layer_norm3(queries) - - # Cross attention block, image embedding attending to tokens - query = queries + query_point_embedding - key = keys + key_point_embedding - - attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) - keys = keys + attn_out - - keys = self.layer_norm4(keys) - - outputs = (queries, keys) - - if output_attentions: - outputs = outputs + (attn_out,) - else: - outputs = outputs + (None,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, None, self.hidden_size]) - if getattr(self, "cross_attn_token_to_image", None) is not None: - with tf.name_scope(self.cross_attn_token_to_image.name): - self.cross_attn_token_to_image.build(None) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, None, self.hidden_size]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "layer_norm3", None) is not None: - with tf.name_scope(self.layer_norm3.name): - self.layer_norm3.build([None, None, None, self.hidden_size]) - if getattr(self, "layer_norm4", None) is not None: - with tf.name_scope(self.layer_norm4.name): - self.layer_norm4.build([None, None, None, self.hidden_size]) - if getattr(self, "cross_attn_image_to_token", None) is not None: - with tf.name_scope(self.cross_attn_image_to_token.name): - self.cross_attn_image_to_token.build(None) - - -class TFSamTwoWayTransformer(keras.layers.Layer): - def __init__(self, config: SamMaskDecoderConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.num_hidden_layers = config.num_hidden_layers - self.layers = [] - - for i in range(self.num_hidden_layers): - self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) - - self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") - self.layer_norm_final_attn = keras.layers.LayerNormalization( - epsilon=config.layer_norm_eps, name="layer_norm_final_attn" - ) - - def call( - self, - point_embeddings: tf.Tensor, - image_embeddings: tf.Tensor, - image_positional_embeddings: tf.Tensor, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | TFBaseModelOutput: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - all_attentions = () - - if image_embeddings is None: - raise ValueError("You have to specify an image_embedding") - - image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] - image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] - - # Prepare queries - queries = point_embeddings - keys = image_embeddings - - # Apply transformer blocks and final layernorm - for layer in self.layers: - queries, keys, attention_outputs = layer( - queries=queries, - keys=keys, - query_point_embedding=point_embeddings, - key_point_embedding=image_positional_embeddings, - output_attentions=output_attentions, - ) - - if output_attentions: - all_attentions = all_attentions + (attention_outputs,) - - # Apply the final attention layer from the points to the image - query = queries + point_embeddings - key = keys + image_positional_embeddings - - attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) - - queries = queries + attn_out - queries = self.layer_norm_final_attn(queries) - return queries, keys, all_attentions - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "final_attn_token_to_image", None) is not None: - with tf.name_scope(self.final_attn_token_to_image.name): - self.final_attn_token_to_image.build(None) - if getattr(self, "layer_norm_final_attn", None) is not None: - with tf.name_scope(self.layer_norm_final_attn.name): - self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size]) - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFSamFeedForward(keras.layers.Layer): - def __init__( - self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs - ): - super().__init__(**kwargs) - self.num_layers = num_layers - self.activation = keras.layers.ReLU() - self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") - self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") - self.layers = [ - keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") - for i in range(num_layers - 2) - ] - self.sigmoid_output = sigmoid_output - self.hidden_dim = hidden_dim - self.input_dim = input_dim - - def call(self, hidden_states): - hidden_states = self.proj_in(hidden_states) - hidden_states = self.activation(hidden_states) - for layer in self.layers: - hidden_states = self.activation(layer(hidden_states)) - - hidden_states = self.proj_out(hidden_states) - if self.sigmoid_output: - hidden_states = tf.sigmoid(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "proj_in", None) is not None: - with tf.name_scope(self.proj_in.name): - self.proj_in.build([None, None, self.input_dim]) - if getattr(self, "proj_out", None) is not None: - with tf.name_scope(self.proj_out.name): - self.proj_out.build([None, None, self.hidden_dim]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build([None, None, self.hidden_dim]) - - -class TFSamMaskDecoder(keras.layers.Layer): - def __init__(self, config: SamMaskDecoderConfig, **kwargs): - super().__init__(**kwargs) - - self.hidden_size = config.hidden_size - - self.num_multimask_outputs = config.num_multimask_outputs - self.num_mask_tokens = config.num_multimask_outputs + 1 - - self.transformer = TFSamTwoWayTransformer(config, name="transformer") - - self.upscale_conv1 = keras.layers.Conv2DTranspose( - self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" - ) - self.upscale_conv2 = keras.layers.Conv2DTranspose( - self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" - ) - self.upscale_layer_norm = TFSamLayerNorm( - self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" - ) - self.activation = tf.nn.gelu - - mlps_list = [] - for i in range(self.num_mask_tokens): - mlps_list += [ - TFSamFeedForward( - self.hidden_size, - self.hidden_size, - self.hidden_size // 8, - 3, - name=f"output_hypernetworks_mlps_._{i}", - ) - ] - self.output_hypernetworks_mlps = mlps_list - - self.iou_prediction_head = TFSamFeedForward( - self.hidden_size, - config.iou_head_hidden_dim, - self.num_mask_tokens, - config.iou_head_depth, - name="iou_prediction_head", - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) - self.mask_tokens = self.add_weight( - shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True - ) - - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "upscale_conv1", None) is not None: - with tf.name_scope(self.upscale_conv1.name): - self.upscale_conv1.build([None, self.hidden_size, None, None]) - if getattr(self, "upscale_conv2", None) is not None: - with tf.name_scope(self.upscale_conv2.name): - self.upscale_conv2.build([None, self.hidden_size // 4, None, None]) - if getattr(self, "upscale_layer_norm", None) is not None: - with tf.name_scope(self.upscale_layer_norm.name): - self.upscale_layer_norm.build(None) - if getattr(self, "iou_prediction_head", None) is not None: - with tf.name_scope(self.iou_prediction_head.name): - self.iou_prediction_head.build(None) - for mlp in self.output_hypernetworks_mlps: - with tf.name_scope(mlp.name): - mlp.build(None) - - def call( - self, - image_embeddings: tf.Tensor, - image_positional_embeddings: tf.Tensor, - sparse_prompt_embeddings: tf.Tensor, - dense_prompt_embeddings: tf.Tensor, - multimask_output: bool, - output_attentions: bool | None = None, - ) -> tuple[tf.Tensor, tf.Tensor]: - batch_size, num_channels, height, width = shape_list(image_embeddings) - point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) - - output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) - output_tokens = tf.tile( - output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] - ) # Should be (batch_size, point_size, 5, 32) - - # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only - # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced - # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. - if shape_list(sparse_prompt_embeddings)[1] != 0: - tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) - else: - tokens = output_tokens - point_embeddings = tf.cast(tokens, self.iou_token.dtype) - - image_embeddings = image_embeddings + dense_prompt_embeddings - image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0) - image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0) - - point_embedding, image_embeddings, attentions = self.transformer( - point_embeddings=point_embeddings, - image_embeddings=image_embeddings, - image_positional_embeddings=image_positional_embeddings, - output_attentions=output_attentions, - ) - iou_token_out = point_embedding[:, :, 0, :] - mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] - - image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) - image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) - - upscaled_embedding = self.upscale_conv1(image_embeddings) - upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) - upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) - - hyper_in_list = [] - for i in range(self.num_mask_tokens): - current_mlp = self.output_hypernetworks_mlps[i] - hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] - hyper_in = tf.stack(hyper_in_list, axis=2) - - _, num_channels, height, width = shape_list(upscaled_embedding) - upscaled_embedding = tf.reshape( - upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] - ) - masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) - - iou_pred = self.iou_prediction_head(iou_token_out) - - if multimask_output: - mask_slice = slice(1, None) - else: - mask_slice = slice(0, 1) - masks = masks[:, :, mask_slice, :, :] - iou_pred = iou_pred[:, :, mask_slice] - - outputs = (masks, iou_pred) - - if output_attentions: - outputs = outputs + (attentions,) - else: - outputs = outputs + (None,) - - return outputs - - -class TFSamPositionalEmbedding(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.scale = config.hidden_size // 2 - self.config = config - - def build(self, input_shape): - # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized? - self.positional_embedding = self.add_weight( - name="positional_embedding", - shape=(2, self.config.num_pos_feats), - initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), - trainable=False, - ) - super().build(input_shape) - - def call(self, input_coords, input_shape=None): - """Positionally encode points that are normalized to [0,1].""" - coordinates = tf.identity(input_coords) - - if input_shape is not None: - coordinates = tf.stack( - [ - tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], - tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], - ], - axis=-1, - ) - - # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape - coordinates = 2 * coordinates - 1 - coordinates = tf.cast(coordinates, self.positional_embedding.dtype) - coordinates = tf.matmul(coordinates, self.positional_embedding) - coordinates = 2 * np.pi * coordinates - # outputs d_1 x ... x d_n x channel shape - return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) - - -class TFSamMaskEmbedding(keras.layers.Layer): - def __init__(self, config: SamPromptEncoderConfig, **kwargs): - super().__init__(**kwargs) - self.mask_input_channels = config.mask_input_channels // 4 - self.activation = ACT2FN[config.hidden_act] - self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") - self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") - self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") - self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") - self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") - self.config = config - - def call(self, masks): - masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last - hidden_states = self.conv1(masks) - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.activation(hidden_states) - - hidden_states = self.conv2(hidden_states) - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.activation(hidden_states) - dense_embeddings = self.conv3(hidden_states) - dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first - return dense_embeddings - - def build(self, input_shape=None): - # This class needs an explicit build method because it isn't called with the standard dummy inputs - if self.built: - return - self.built = True - with tf.name_scope("conv1"): - self.conv1.build([None, None, None, 1]) - with tf.name_scope("conv2"): - self.conv2.build([None, None, None, self.mask_input_channels]) - with tf.name_scope("conv3"): - self.conv3.build([None, None, None, self.mask_input_channels * 4]) - with tf.name_scope("layer_norm1"): - self.layer_norm1.build([None, None, None, self.mask_input_channels]) - with tf.name_scope("layer_norm2"): - self.layer_norm2.build([None, None, None, self.mask_input_channels * 4]) - - -class TFSamPromptEncoder(keras.layers.Layer): - def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): - super().__init__(**kwargs) - self.shared_embedding = shared_patch_embedding - self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") - self.no_mask_embed = None - - self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) - self.input_image_size = config.image_size - - self.point_embed = [] - self.hidden_size = config.hidden_size - self.not_a_point_embed = None - self.config = config - - def build(self, input_shape=None): - self.no_mask_embed = self.add_weight( - name="no_mask_embed.weight", - shape=(1, self.hidden_size), - initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), - trainable=True, - ) - self.point_embed = [ - self.add_weight( - name=f"point_embed_._{i}.weight", - shape=(1, self.hidden_size), - initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), - trainable=True, - ) - for i in range(self.config.num_point_embeddings) - ] - self.not_a_point_embed = self.add_weight( - name="not_a_point_embed.weight", - shape=(1, self.hidden_size), - initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), - trainable=True, - ) - with tf.name_scope("mask_embed"): - # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs - self.mask_embed.build( - (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size) - ) - - if self.built: - return - self.built = True - if getattr(self, "mask_embed", None) is not None: - with tf.name_scope(self.mask_embed.name): - self.mask_embed.build(None) - - def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: - """Embeds point prompts.""" - points = points + 0.5 # Shift to center of pixel - if pad: - target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1]) - target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1) - padding_point = tf.zeros(target_point_shape, dtype=points.dtype) - padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) - points = tf.concat([points, padding_point], axis=2) - labels = tf.concat([labels, padding_label], axis=2) - input_shape = (self.input_image_size, self.input_image_size) - point_embedding = self.shared_embedding(points, input_shape) - - point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) - - point_embedding = tf.where( - labels[..., None] != -10, - point_embedding, - tf.zeros_like(point_embedding), - ) - point_embedding = tf.where( - (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding - ) - point_embedding = tf.where( - (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding - ) - return point_embedding - - def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: - """Embeds box prompts.""" - boxes = boxes + 0.5 # Shift to center of pixel - batch_size, nb_boxes = shape_list(boxes)[:2] - coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) - input_shape = (self.input_image_size, self.input_image_size) - corner_embedding = self.shared_embedding(coords, input_shape) - corner_embedding += tf.where( - tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0, - self.point_embed[2][0], - self.point_embed[3][0], - ) - return corner_embedding - - def call( - self, - batch_size: int | None, - input_points: tuple[tf.Tensor, tf.Tensor] | None, - input_labels: tf.Tensor | None, - input_boxes: tf.Tensor | None, - input_masks: tf.Tensor | None, - ) -> tuple[tf.Tensor, tf.Tensor]: - """ - Embeds different types of prompts, returning both sparse and dense embeddings. - - Args: - points (`tf.Tensor`, *optional*): - point coordinates and labels to embed. - boxes (`tf.Tensor`, *optional*): - boxes to embed - masks (`tf.Tensor`, *optional*): - masks to embed - """ - sparse_embeddings = None - if input_points is not None: - batch_size, point_batch_size = shape_list(input_points)[:2] - if input_labels is None: - raise ValueError("If points are provided, labels must also be provided.") - point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) - sparse_embeddings = tf.zeros( - (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype - ) - sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) - if input_boxes is not None: - batch_size = shape_list(input_boxes)[0] - box_embeddings = self._embed_boxes(input_boxes) - if sparse_embeddings is None: - sparse_embeddings = box_embeddings - else: - sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) - if input_masks is not None: - dense_embeddings = self.mask_embed(input_masks) - else: - dense_embeddings = self.no_mask_embed[0] - dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) - dense_embeddings = tf.tile( - dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) - ) - if sparse_embeddings is None: - sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) - - return sparse_embeddings, dense_embeddings - - -class TFSamVisionAttention(keras.layers.Layer): - """Multi-head Attention block with relative position embeddings.""" - - def __init__(self, config, window_size, **kwargs): - super().__init__(**kwargs) - input_size = ( - (config.image_size // config.patch_size, config.image_size // config.patch_size) - if window_size == 0 - else (window_size, window_size) - ) - self.input_size = input_size - - self.num_attention_heads = config.num_attention_heads - head_dim = config.hidden_size // config.num_attention_heads - self.head_dim = head_dim - self.scale = head_dim**-0.5 - self.dropout = config.attention_dropout - - self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") - self.proj = keras.layers.Dense(config.hidden_size, name="proj") - - self.use_rel_pos = config.use_rel_pos - if self.use_rel_pos: - if input_size is None: - raise ValueError("Input size must be provided if using relative positional encoding.") - self.config = config - - def build(self, input_shape=None): - if self.input_size is not None: - # initialize relative positional embeddings - self.rel_pos_h = self.add_weight( - shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" - ) - self.rel_pos_w = self.add_weight( - shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" - ) - - if self.built: - return - self.built = True - if getattr(self, "qkv", None) is not None: - with tf.name_scope(self.qkv.name): - self.qkv.build([None, None, self.config.hidden_size]) - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build([None, None, self.config.hidden_size]) - - def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - - Args: - q_size (int): - size of the query. - k_size (int): - size of key k. - rel_pos (`tf.Tensor`): - relative position embeddings (L, channel). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = tf.image.resize( - tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), - size=(max_rel_dist, rel_pos.shape[1]), - method="bilinear", - ) - rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) - k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) - - def get_decomposed_rel_pos( - self, - query: tf.Tensor, - rel_pos_h: tf.Tensor, - rel_pos_w: tf.Tensor, - q_size: tuple[int, int], - k_size: tuple[int, int], - ) -> tf.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py - - Args: - query (`tf.Tensor`): - query q in the attention layer with shape (batch_size, query_height * query_width, channel). - rel_pos_h (`tf.Tensor`): - relative position embeddings (Lh, channel) for height axis. - rel_pos_w (`tf.Tensor`): - relative position embeddings (Lw, channel) for width axis. - q_size (tuple): - spatial sequence size of query q with (query_height, query_width). - k_size (tuple): - spatial sequence size of key k with (key_height, key_width). - - Returns: - decomposed_rel_pos (`torch.Tensor`): - decomposed relative position embeddings. - """ - query_height, query_width = q_size - key_height, key_width = k_size - relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) - relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) - - batch_size, _, dim = shape_list(query) - reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) - rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) - rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) - - rel_h = tf.expand_dims(rel_h, axis=-1) - rel_w = tf.expand_dims(rel_w, axis=-2) - decomposed_rel_pos = rel_h + rel_w - - return decomposed_rel_pos - - def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: - batch_size, height, width, _ = shape_list(hidden_states) - # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) - qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) - # q, k, v with shape (batch_size * nHead, height * width, channel) - query, key, value = tf.unstack( - tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 - ) - attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) - - if self.use_rel_pos: - decomposed_rel_pos = self.get_decomposed_rel_pos( - query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) - ) - decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights)) - attn_weights = attn_weights + decomposed_rel_pos - - attn_weights = tf.nn.softmax(attn_weights, axis=-1) - - if training: - attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) - else: - attn_probs = attn_weights - - attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) - attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) - attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size)) - - attn_output = self.proj(attn_output) - - if output_attentions: - outputs = (attn_output, attn_weights) - else: - outputs = (attn_output, None) - - return outputs - - -class TFSamVisionLayer(keras.layers.Layer): - def __init__(self, config, window_size, **kwargs): - super().__init__(**kwargs) - self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") - self.attn = TFSamVisionAttention(config, window_size, name="attn") - self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") - self.mlp = TFSamMLPBlock(config, name="mlp") - self.window_size = window_size - self.config = config - - def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> tuple[tf.Tensor, tuple[int, int]]: - batch_size, height, width, channel = shape_list(hidden_states) - - pad_h = (window_size - height % window_size) % window_size - pad_w = (window_size - width % window_size) % window_size - if pad_h > 0 or pad_w > 0: - hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) - pad_height, pad_width = height + pad_h, width + pad_w - - hidden_states = tf.reshape( - hidden_states, - [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], - ) - windows = tf.reshape( - tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] - ) - return windows, (pad_height, pad_width) - - def window_unpartition( - self, windows: tf.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] - ) -> tf.Tensor: - pad_height, pad_width = padding_shape - height, width = original_shape - batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) - hidden_states = tf.reshape( - windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] - ) - hidden_states = tf.reshape( - tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] - ) - - if pad_height > height or pad_width > width: - hidden_states = hidden_states[:, :height, :width, :] - return hidden_states - - def call( - self, - hidden_states: tf.Tensor, - output_attentions: bool | None = False, - training: bool | None = False, - ) -> tuple[tf.Tensor]: - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - if self.window_size > 0: - height, width = hidden_states.shape[1], hidden_states.shape[2] - hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) - - hidden_states, attn_weights = self.attn( - hidden_states=hidden_states, - output_attentions=output_attentions, - training=training, - ) - if self.window_size > 0: - hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) - - hidden_states = residual + hidden_states - layernorm_output = self.layer_norm2(hidden_states) - hidden_states = hidden_states + self.mlp(layernorm_output) - - outputs = (hidden_states,) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build([None, None, None, self.config.hidden_size]) - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build([None, None, None, self.config.hidden_size]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - - -class TFSamVisionNeck(keras.layers.Layer): - def __init__(self, config: SamVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.conv1 = keras.layers.Conv2D( - config.output_channels, - kernel_size=1, - use_bias=False, - name="conv1", - ) - self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") - self.conv2 = keras.layers.Conv2D( - config.output_channels, - kernel_size=3, - padding="same", - use_bias=False, - name="conv2", - ) - self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") - - def call(self, hidden_states): - hidden_states = self.conv1(hidden_states) - hidden_states = self.layer_norm1(hidden_states) - - hidden_states = self.conv2(hidden_states) - hidden_states = self.layer_norm2(hidden_states) - hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv1", None) is not None: - with tf.name_scope(self.conv1.name): - self.conv1.build([None, None, None, self.config.hidden_size]) - if getattr(self, "layer_norm1", None) is not None: - with tf.name_scope(self.layer_norm1.name): - self.layer_norm1.build(None) - if getattr(self, "conv2", None) is not None: - with tf.name_scope(self.conv2.name): - self.conv2.build([None, None, None, self.config.output_channels]) - if getattr(self, "layer_norm2", None) is not None: - with tf.name_scope(self.layer_norm2.name): - self.layer_norm2.build(None) - - -class TFSamVisionEncoder(keras.layers.Layer): - def __init__(self, config: SamVisionConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.image_size = config.image_size - - self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") - - self.pos_embed = None - - self.layers = [] - for i in range(config.num_hidden_layers): - layer = TFSamVisionLayer( - config, - window_size=config.window_size if i not in config.global_attn_indexes else 0, - name=f"layers_._{i}", - ) - self.layers.append(layer) - - self.neck = TFSamVisionNeck(config, name="neck") - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.config.use_abs_pos: - # Initialize absolute positional embedding with pretrain image size. - self.pos_embed = self.add_weight( - shape=[ - 1, - self.config.image_size // self.config.patch_size, - self.config.image_size // self.config.patch_size, - self.config.hidden_size, - ], - initializer="zeros", - trainable=True, - name="pos_embed", - ) - - if getattr(self, "patch_embed", None) is not None: - with tf.name_scope(self.patch_embed.name): - self.patch_embed.build(None) - if getattr(self, "neck", None) is not None: - with tf.name_scope(self.neck.name): - self.neck.build(None) - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - def get_input_embeddings(self): - return self.patch_embed - - def call( - self, - pixel_values: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFSamVisionEncoderOutput: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.patch_embed(pixel_values) - if self.pos_embed is not None: - hidden_states = hidden_states + self.pos_embed - - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = self.neck(hidden_states) - - if not return_dict: - outputs = (hidden_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_self_attentions,) - return outputs - - return TFSamVisionEncoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class TFSamPreTrainedModel(TFPreTrainedModel): - config_class = SamConfig - base_model_prefix = "sam" - main_input_name = "pixel_values" - - -SAM_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) - subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to - general usage and behavior. - - Parameters: - config ([`SamConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -SAM_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for - details. - input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): - Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much - better results. The points can be obtained by passing a list of list of list to the processor that will - create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second - dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per - input point), the third dimension is the number of points per segmentation mask (it is possible to pass - multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) - coordinates of the point. If a different number of points is passed either for each image, or for each - mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the - computation of the embedding will be skipped for these points using the labels. - input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`): - Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the - official implementation, there are 3 types of labels - - - `1`: the point is a point that contains the object of interest - - `0`: the point is a point that does not contain the object of interest - - `-1`: the point corresponds to the background - - We added the label: - - - `-10`: the point is a padding point, thus should be ignored by the prompt encoder - - The padding labels should be automatically done by the processor. - input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): - Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to - much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, - that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, - the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the - order (`x1`, `y1`, `x2`, `y2`): - - - `x1`: the x coordinate of the top left point of the input box - - `y1`: the y coordinate of the top left point of the input box - - `x2`: the x coordinate of the bottom right point of the input box - - `y2`: the y coordinate of the bottom right point of the input box - - input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): - SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to - generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be - manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). - - image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): - Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory - efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` - method, and then feed them to the `call` method instead of feeding the `pixel_values`. - multimask_output (`bool`, *optional*): - In the original implementation and paper, the model always outputs 3 masks per image (or per point / per - bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the - "best" mask, by specifying `multimask_output=False`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -SAM_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for - details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - """The vision model from Sam without any head or projection on top.""", - SAM_START_DOCSTRING, -) -class TFSamVisionModel(TFSamPreTrainedModel): - config_class = SamVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: SamVisionConfig, **kwargs): - super().__init__(config, **kwargs) - self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder") - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vision_encoder", None) is not None: - with tf.name_scope(self.vision_encoder.name): - self.vision_encoder.build(None) - - def get_input_embeddings(self): - return self.vision_encoder.patch_embed - - @unpack_inputs - @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig) - def call( - self, - pixel_values: TFModelInputType | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFSamVisionEncoderOutput | tuple[tf.Tensor]: - r""" - Returns: - - """ - return self.vision_encoder( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - -@add_start_docstrings( - "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", - " optional 2D location and bounding boxes.", - SAM_START_DOCSTRING, -) -class TFSamModel(TFSamPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") - - self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") - self.prompt_encoder = TFSamPromptEncoder( - config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" - ) - self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") - self.config = config - - def get_input_embeddings(self): - return self.vision_encoder.get_input_embeddings() - - def get_image_wide_positional_embeddings(self): - size = self.config.prompt_encoder_config.image_embedding_size - grid = tf.ones((size, size)) - y_embed = tf.math.cumsum(grid, axis=0) - 0.5 - x_embed = tf.math.cumsum(grid, axis=1) - 0.5 - y_embed = y_embed / size - x_embed = x_embed / size - - positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) - return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width - - def get_image_embeddings( - self, - pixel_values, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ): - r""" - Returns the image embeddings by passing the pixel values through the vision encoder. - - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Input pixel values - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple. - - """ - vision_output = self.vision_encoder( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - image_embeddings = vision_output[0] - return image_embeddings - - def get_prompt_embeddings( - self, - input_points: tf.Tensor | None = None, - input_labels: tf.Tensor | None = None, - input_boxes: tf.Tensor | None = None, - input_masks: tf.Tensor | None = None, - ): - r""" - Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. - - Args: - input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): - Optional input points for the prompt encoder. The padding of the point is automatically done by the - processor. `point_batch_size` refers to the number of masks that we want the model to predict per - point. The model will output `point_batch_size` times 3 masks in total. - input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): - Optional input labels for the prompt encoder. The padding of the labels is automatically done by the - processor, or can be fed by the user. - input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): - Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the - processor. users can also pass manually the input boxes. - input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): - Optional input masks for the prompt encoder. - """ - prompt_output = self.prompt_encoder( - input_points=input_points, - input_labels=input_labels, - input_boxes=input_boxes, - input_masks=input_masks, - ) - return prompt_output - - @unpack_inputs - @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) - def call( - self, - pixel_values: TFModelInputType | None = None, - input_points: tf.Tensor | None = None, - input_labels: tf.Tensor | None = None, - input_boxes: tf.Tensor | None = None, - input_masks: tf.Tensor | None = None, - image_embeddings: tf.Tensor | None = None, - multimask_output: bool = True, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFSamImageSegmentationOutput | tuple[tf.Tensor]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None and image_embeddings is None: - raise ValueError("Either pixel_values or image_embeddings must be provided.") - - if pixel_values is not None and image_embeddings is not None: - raise ValueError("Only one of pixel_values and image_embeddings can be provided.") - - if input_points is not None and len(input_points.shape) != 4: - raise ValueError( - "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", - f" got {input_points.shape}.", - ) - if input_boxes is not None and len(input_boxes.shape) != 3: - raise ValueError( - "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", - f" got {input_boxes.shape}.", - ) - if input_points is not None and input_boxes is not None: - point_batch_size = shape_list(input_points)[1] - box_batch_size = shape_list(input_boxes)[1] - if point_batch_size != box_batch_size: - raise ValueError( - f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}." - ) - if pixel_values is not None: - # Ensures that later checks pass even with an all-None shape from the serving signature - pixel_values = tf.ensure_shape( - pixel_values, - [ - None, - self.config.vision_config.num_channels, - self.config.vision_config.image_size, - self.config.vision_config.image_size, - ], - ) - image_positional_embeddings = self.get_image_wide_positional_embeddings() - # repeat with batch size - batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] - image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) - - vision_attentions = None - vision_hidden_states = None - - if pixel_values is not None: - vision_outputs = self.vision_encoder( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - training=training, - ) - image_embeddings = vision_outputs["last_hidden_state"] - - if output_hidden_states: - vision_hidden_states = vision_outputs["hidden_states"] - if output_attentions: - vision_attentions = vision_outputs["attentions"] - - if input_points is not None and input_labels is None: - input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) - - if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: - raise ValueError( - "The batch size of the image embeddings and the input points must be the same. ", - f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.", - " if you want to pass multiple points for the same image, make sure that you passed ", - " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", - " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", - ) - - sparse_embeddings, dense_embeddings = self.prompt_encoder( - batch_size=shape_list(image_embeddings)[0], - input_points=input_points, - input_labels=input_labels, - input_boxes=input_boxes, - input_masks=input_masks, - ) - - low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( - image_embeddings=image_embeddings, - image_positional_embeddings=image_positional_embeddings, - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - output_attentions=output_attentions, - ) - - if not return_dict: - output = (iou_predictions, low_res_masks) - if output_hidden_states: - output = output + (vision_hidden_states,) - - if output_attentions: - output = output + (vision_attentions, mask_decoder_attentions) - return output - - return TFSamImageSegmentationOutput( - iou_scores=iou_predictions, - pred_masks=low_res_masks, - vision_hidden_states=vision_hidden_states, - vision_attentions=vision_attentions, - mask_decoder_attentions=mask_decoder_attentions, - ) - - def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: - hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None - attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None - - return TFSamImageSegmentationOutput( - iou_scores=output.iou_scores, - pred_masks=output.pred_masks, - vision_hidden_states=hs if self.config.output_hidden_states else None, - vision_attentions=attns if self.config.output_attentions else None, - mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "shared_image_embedding", None) is not None: - with tf.name_scope(self.shared_image_embedding.name): - self.shared_image_embedding.build(None) - if getattr(self, "vision_encoder", None) is not None: - with tf.name_scope(self.vision_encoder.name): - self.vision_encoder.build(None) - if getattr(self, "prompt_encoder", None) is not None: - with tf.name_scope(self.prompt_encoder.name): - self.prompt_encoder.build(None) - if getattr(self, "mask_decoder", None) is not None: - with tf.name_scope(self.mask_decoder.name): - self.mask_decoder.build(None) - - -__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"] diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py deleted file mode 100644 index 2f8e68b95748..000000000000 --- a/src/transformers/models/segformer/modeling_tf_segformer.py +++ /dev/null @@ -1,1044 +0,0 @@ -# coding=utf-8 -# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow SegFormer model.""" - -from __future__ import annotations - -import math - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import logging -from .configuration_segformer import SegformerConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "SegformerConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "nvidia/mit-b0" -_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer -class TFSegformerDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - References: - (1) github.com:rwightman/pytorch-image-models - """ - - def __init__(self, drop_path: float, **kwargs): - super().__init__(**kwargs) - self.drop_path = drop_path - - def call(self, x: tf.Tensor, training=None): - if training: - keep_prob = 1 - self.drop_path - shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) - random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) - random_tensor = tf.floor(random_tensor) - return (x / keep_prob) * random_tensor - return x - - -class TFSegformerOverlapPatchEmbeddings(keras.layers.Layer): - """Construct the overlapping patch embeddings.""" - - def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs): - super().__init__(**kwargs) - self.padding = keras.layers.ZeroPadding2D(padding=patch_size // 2) - self.proj = keras.layers.Conv2D( - filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj" - ) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm") - self.num_channels = num_channels - self.hidden_size = hidden_size - - def call(self, pixel_values: tf.Tensor) -> tuple[tf.Tensor, int, int]: - embeddings = self.proj(self.padding(pixel_values)) - height = shape_list(embeddings)[1] - width = shape_list(embeddings)[2] - hidden_dim = shape_list(embeddings)[3] - # (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels) - # this can be fed to a Transformer layer - embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim)) - embeddings = self.layer_norm(embeddings) - return embeddings, height, width - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build([None, None, None, self.num_channels]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.hidden_size]) - - -class TFSegformerEfficientSelfAttention(keras.layers.Layer): - """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT - paper](https://huggingface.co/papers/2102.12122).""" - - def __init__( - self, - config: SegformerConfig, - hidden_size: int, - num_attention_heads: int, - sequence_reduction_ratio: int, - **kwargs, - ): - super().__init__(**kwargs) - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " - f"heads ({self.num_attention_heads})" - ) - - self.attention_head_size = self.hidden_size // self.num_attention_heads - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense(self.all_head_size, name="query") - self.key = keras.layers.Dense(self.all_head_size, name="key") - self.value = keras.layers.Dense(self.all_head_size, name="value") - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - - self.sr_ratio = sequence_reduction_ratio - if sequence_reduction_ratio > 1: - self.sr = keras.layers.Conv2D( - filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm") - - def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] - # to [batch_size, seq_length, num_attention_heads, attention_head_size] - batch_size = shape_list(tensor)[0] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] - # to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - height: int, - width: int, - output_attentions: bool = False, - training: bool = False, - ) -> tf.Tensor | tuple[tf.Tensor, tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - num_channels = shape_list(hidden_states)[2] - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - if self.sr_ratio > 1: - # Reshape to (batch_size, height, width, num_channels) - hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) - # Apply sequence reduction - hidden_states = self.sr(hidden_states) - # Reshape back to (batch_size, seq_len, num_channels) - hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels)) - hidden_states = self.layer_norm(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - - scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, scale) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - context_layer = tf.matmul(attention_probs, value_layer) - - context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) - # (batch_size, seq_len_q, all_head_size) - context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size)) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.hidden_size]) - if getattr(self, "sr", None) is not None: - with tf.name_scope(self.sr.name): - self.sr.build([None, None, None, self.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.hidden_size]) - - -class TFSegformerSelfOutput(keras.layers.Layer): - def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense(hidden_size, name="dense") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.hidden_size = hidden_size - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.hidden_size]) - - -class TFSegformerAttention(keras.layers.Layer): - def __init__( - self, - config: SegformerConfig, - hidden_size: int, - num_attention_heads: int, - sequence_reduction_ratio: int, - **kwargs, - ): - super().__init__(**kwargs) - self.self = TFSegformerEfficientSelfAttention( - config=config, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - sequence_reduction_ratio=sequence_reduction_ratio, - name="self", - ) - self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output") - - def call( - self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False - ) -> tf.Tensor | tuple[tf.Tensor, tf.Tensor]: - self_outputs = self.self(hidden_states, height, width, output_attentions) - - attention_output = self.dense_output(self_outputs[0]) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFSegformerDWConv(keras.layers.Layer): - def __init__(self, dim: int = 768, **kwargs): - super().__init__(**kwargs) - self.depthwise_convolution = keras.layers.Conv2D( - filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv" - ) - self.dim = dim - - def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor: - batch_size = shape_list(hidden_states)[0] - num_channels = shape_list(hidden_states)[-1] - hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) - hidden_states = self.depthwise_convolution(hidden_states) - - new_height = shape_list(hidden_states)[1] - new_width = shape_list(hidden_states)[2] - num_channels = shape_list(hidden_states)[3] - hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels)) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "depthwise_convolution", None) is not None: - with tf.name_scope(self.depthwise_convolution.name): - self.depthwise_convolution.build([None, None, None, self.dim]) - - -class TFSegformerMixFFN(keras.layers.Layer): - def __init__( - self, - config: SegformerConfig, - in_features: int, - hidden_features: int | None = None, - out_features: int | None = None, - **kwargs, - ): - super().__init__(**kwargs) - out_features = out_features or in_features - self.dense1 = keras.layers.Dense(hidden_features, name="dense1") - self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv") - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.dense2 = keras.layers.Dense(out_features, name="dense2") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.hidden_features = hidden_features - self.in_features = in_features - - def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: - hidden_states = self.dense1(hidden_states) - hidden_states = self.depthwise_convolution(hidden_states, height, width) - hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.dense2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense1", None) is not None: - with tf.name_scope(self.dense1.name): - self.dense1.build([None, None, self.in_features]) - if getattr(self, "depthwise_convolution", None) is not None: - with tf.name_scope(self.depthwise_convolution.name): - self.depthwise_convolution.build(None) - if getattr(self, "dense2", None) is not None: - with tf.name_scope(self.dense2.name): - self.dense2.build([None, None, self.hidden_features]) - - -class TFSegformerLayer(keras.layers.Layer): - """This corresponds to the Block class in the original implementation.""" - - def __init__( - self, - config, - hidden_size: int, - num_attention_heads: int, - drop_path: float, - sequence_reduction_ratio: int, - mlp_ratio: int, - **kwargs, - ): - super().__init__(**kwargs) - self.layer_norm_1 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1") - self.attention = TFSegformerAttention( - config, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - sequence_reduction_ratio=sequence_reduction_ratio, - name="attention", - ) - self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else keras.layers.Activation("linear") - self.layer_norm_2 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2") - mlp_hidden_size = int(hidden_size * mlp_ratio) - self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp") - self.hidden_size = hidden_size - - def call( - self, - hidden_states: tf.Tensor, - height: int, - width: int, - output_attentions: bool = False, - training: bool = False, - ) -> tuple: - self_attention_outputs = self.attention( - self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention - height, - width, - output_attentions=output_attentions, - training=training, - ) - - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - # first residual connection (with stochastic depth) - attention_output = self.drop_path(attention_output, training=training) - hidden_states = attention_output + hidden_states - mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) - - # second residual connection (with stochastic depth) - mlp_output = self.drop_path(mlp_output, training=training) - layer_output = mlp_output + hidden_states - - outputs = (layer_output,) + outputs - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm_1", None) is not None: - with tf.name_scope(self.layer_norm_1.name): - self.layer_norm_1.build([None, None, self.hidden_size]) - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "layer_norm_2", None) is not None: - with tf.name_scope(self.layer_norm_2.name): - self.layer_norm_2.build([None, None, self.hidden_size]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - - -class TFSegformerEncoder(keras.layers.Layer): - def __init__(self, config: SegformerConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - # stochastic depth decay rule - drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))] - - # patch embeddings - embeddings = [] - for i in range(config.num_encoder_blocks): - embeddings.append( - TFSegformerOverlapPatchEmbeddings( - patch_size=config.patch_sizes[i], - stride=config.strides[i], - num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], - hidden_size=config.hidden_sizes[i], - name=f"patch_embeddings.{i}", - ) - ) - self.embeddings = embeddings - - # Transformer blocks - blocks = [] - cur = 0 - for i in range(config.num_encoder_blocks): - # each block consists of layers - layers = [] - if i != 0: - cur += config.depths[i - 1] - for j in range(config.depths[i]): - layers.append( - TFSegformerLayer( - config, - hidden_size=config.hidden_sizes[i], - num_attention_heads=config.num_attention_heads[i], - drop_path=drop_path_decays[cur + j], - sequence_reduction_ratio=config.sr_ratios[i], - mlp_ratio=config.mlp_ratios[i], - name=f"block.{i}.{j}", - ) - ) - blocks.append(layers) - - self.block = blocks - - # Layer norms - self.layer_norms = [ - keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}") - for i in range(config.num_encoder_blocks) - ] - - def call( - self, - pixel_values: tf.Tensor, - output_attentions: bool | None = False, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - batch_size = shape_list(pixel_values)[0] - - hidden_states = pixel_values - for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)): - embedding_layer, block_layer, norm_layer = x - # first, obtain patch embeddings - hidden_states, height, width = embedding_layer(hidden_states) - - # second, send embeddings through blocks - # (each block consists of multiple layers i.e., list of layers) - for i, blk in enumerate(block_layer): - layer_outputs = blk( - hidden_states, - height, - width, - output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - # third, apply layer norm - hidden_states = norm_layer(hidden_states) - - # fourth, optionally reshape back to (batch_size, height, width, num_channels) - if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage): - num_channels = shape_list(hidden_states)[-1] - hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norms", None) is not None: - for layer, shape in zip(self.layer_norms, self.config.hidden_sizes): - with tf.name_scope(layer.name): - layer.build([None, None, shape]) - if getattr(self, "block", None) is not None: - for block in self.block: - for layer in block: - with tf.name_scope(layer.name): - layer.build(None) - if getattr(self, "embeddings", None) is not None: - for layer in self.embeddings: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFSegformerMainLayer(keras.layers.Layer): - config_class = SegformerConfig - - def __init__(self, config: SegformerConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - # hierarchical Transformer encoder - self.encoder = TFSegformerEncoder(config, name="encoder") - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - encoder_outputs = self.encoder( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = encoder_outputs[0] - # Change to NCHW output format to have uniformity in the modules - sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2]) - - # Change the other hidden state outputs to NCHW as well - if output_hidden_states: - hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]) - - if not return_dict: - if tf.greater(len(encoder_outputs[1:]), 0): - transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0]) - return (sequence_output,) + (transposed_encoder_outputs,) - else: - return (sequence_output,) + encoder_outputs[1:] - - return TFBaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -class TFSegformerPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SegformerConfig - base_model_prefix = "segformer" - main_input_name = "pixel_values" - - @property - def input_signature(self): - return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 512, 512), dtype=tf.float32)} - - -SEGFORMER_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`SegformerConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -SEGFORMER_INPUTS_DOCSTRING = r""" - - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`SegformerImageProcessor.__call__`] for details. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.", - SEGFORMER_START_DOCSTRING, -) -class TFSegformerModel(TFSegformerPreTrainedModel): - def __init__(self, config: SegformerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - - # hierarchical Transformer encoder - self.segformer = TFSegformerMainLayer(config, name="segformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: tf.Tensor, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFBaseModelOutput: - outputs = self.segformer( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "segformer", None) is not None: - with tf.name_scope(self.segformer.name): - self.segformer.build(None) - - -@add_start_docstrings( - """ - SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden - states) e.g. for ImageNet. - """, - SEGFORMER_START_DOCSTRING, -) -class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: SegformerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.segformer = TFSegformerMainLayer(config, name="segformer") - - # Classifier head - self.classifier = keras.layers.Dense(config.num_labels, name="classifier") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | TFSequenceClassifierOutput: - outputs = self.segformer( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - # convert last hidden states to (batch_size, height*width, hidden_size) - batch_size = shape_list(sequence_output)[0] - sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1]) - sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1])) - - # global average pooling - sequence_output = tf.reduce_mean(sequence_output, axis=1) - - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "segformer", None) is not None: - with tf.name_scope(self.segformer.name): - self.segformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_sizes[-1]]) - - -class TFSegformerMLP(keras.layers.Layer): - """ - Linear Embedding. - """ - - def __init__(self, input_dim: int, config: SegformerConfig, **kwargs): - super().__init__(**kwargs) - self.proj = keras.layers.Dense(config.decoder_hidden_size, name="proj") - self.input_dim = input_dim - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - height = shape_list(hidden_states)[1] - width = shape_list(hidden_states)[2] - hidden_dim = shape_list(hidden_states)[-1] - hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim)) - hidden_states = self.proj(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build([None, None, self.input_dim]) - - -class TFSegformerDecodeHead(TFSegformerPreTrainedModel): - def __init__(self, config: SegformerConfig, **kwargs): - super().__init__(config, **kwargs) - # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size - mlps = [] - for i in range(config.num_encoder_blocks): - mlp = TFSegformerMLP(config=config, input_dim=config.hidden_sizes[i], name=f"linear_c.{i}") - mlps.append(mlp) - self.mlps = mlps - - # the following 3 layers implement the ConvModule of the original implementation - self.linear_fuse = keras.layers.Conv2D( - filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse" - ) - self.batch_norm = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="batch_norm") - self.activation = keras.layers.Activation("relu") - - self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) - self.classifier = keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier") - - self.config = config - - def call(self, encoder_hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - all_hidden_states = () - for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps): - if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3: - height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32)) - height = width = tf.cast(height, tf.int32) - channel_dim = shape_list(encoder_hidden_state)[-1] - encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) - - # unify channel dimension - encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1]) - height, width = shape_list(encoder_hidden_state)[1:3] - encoder_hidden_state = mlp(encoder_hidden_state) - channel_dim = shape_list(encoder_hidden_state)[-1] - encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) - - # upsample - temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1]) - upsample_resolution = shape_list(temp_state)[1:-1] - encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear") - all_hidden_states += (encoder_hidden_state,) - - hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1)) - hidden_states = self.batch_norm(hidden_states, training=training) - hidden_states = self.activation(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # logits of shape (batch_size, height/4, width/4, num_labels) - logits = self.classifier(hidden_states) - - return logits - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "linear_fuse", None) is not None: - with tf.name_scope(self.linear_fuse.name): - self.linear_fuse.build( - [None, None, None, self.config.decoder_hidden_size * self.config.num_encoder_blocks] - ) - if getattr(self, "batch_norm", None) is not None: - with tf.name_scope(self.batch_norm.name): - self.batch_norm.build([None, None, None, self.config.decoder_hidden_size]) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, None, self.config.decoder_hidden_size]) - if getattr(self, "mlps", None) is not None: - for layer in self.mlps: - with tf.name_scope(layer.name): - layer.build(None) - - -@add_start_docstrings( - """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""", - SEGFORMER_START_DOCSTRING, -) -class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel): - def __init__(self, config: SegformerConfig, **kwargs): - super().__init__(config, **kwargs) - self.segformer = TFSegformerMainLayer(config, name="segformer") - self.decode_head = TFSegformerDecodeHead(config, name="decode_head") - - def hf_compute_loss(self, logits, labels): - # upsample logits to the images' original size - # `labels` is of shape (batch_size, height, width) - label_interp_shape = shape_list(labels)[1:] - - upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") - # compute weighted loss - loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") - - def masked_loss(real, pred): - unmasked_loss = loss_fct(real, pred) - mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype) - masked_loss = unmasked_loss * mask - # Reduction strategy in the similar spirit with - # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210 - reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask) - return tf.reshape(reduced_masked_loss, (1,)) - - return masked_loss(labels, upsampled_logits) - - @unpack_inputs - @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: tf.Tensor, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | TFSemanticSegmenterOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): - Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed - (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") - >>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs, training=False) - >>> # logits are of shape (batch_size, num_labels, height/4, width/4) - >>> logits = outputs.logits - >>> list(logits.shape) - [1, 150, 128, 128] - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if labels is not None and not self.config.num_labels > 1: - raise ValueError("The number of labels should be greater than one") - - outputs = self.segformer( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=True, # we need the intermediate hidden states - return_dict=return_dict, - ) - - encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] - - logits = self.decode_head(encoder_hidden_states) - - loss = None - if labels is not None: - loss = self.hf_compute_loss(logits=logits, labels=labels) - - # make logits of shape (batch_size, num_labels, height, width) to - # keep them consistent across APIs - logits = tf.transpose(logits, perm=[0, 3, 1, 2]) - - if not return_dict: - if output_hidden_states: - output = (logits,) + outputs[1:] - else: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSemanticSegmenterOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "segformer", None) is not None: - with tf.name_scope(self.segformer.name): - self.segformer.build(None) - if getattr(self, "decode_head", None) is not None: - with tf.name_scope(self.decode_head.name): - self.decode_head.build(None) - - -__all__ = [ - "TFSegformerDecodeHead", - "TFSegformerForImageClassification", - "TFSegformerForSemanticSegmentation", - "TFSegformerModel", - "TFSegformerPreTrainedModel", -] diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py deleted file mode 100644 index 3614c5d4981b..000000000000 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ /dev/null @@ -1,930 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes to support Flax Speech-Encoder-Decoder architectures""" - -import os -from typing import Optional, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput -from ...modeling_flax_utils import FlaxPreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM -from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" - -SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech - autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is - loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via - [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder - and should be fine-tuned on a downstream generative task, like summarization. - - The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation - tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation - Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi - Zhou, Wei Li, Peter J. Liu. - - Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech - Translation](https://huggingface.co/papers/2104.06678) it is shown how leveraging large pretrained speech models for speech - translation yields a significant performance improvement. - - After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other - models (see the examples for more information). - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Parameters: - config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): - Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac` - or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* - via the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). - To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or - [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type - `torch.FloatTensor`. - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be - created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` - and prepending them with the `decoder_start_token_id`. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. -""" - -SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" - Args: - inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): - Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* - or *.wav* audio file into an array of type *list[float]* or a *numpy.ndarray*, *e.g.* via the torchcodec library - (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). - To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should be used - for padding and conversion into a tensor of type *torch.FloatTensor*. - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. -""" - -SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be - created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` - and prepending them with the `decoder_start_token_id`. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a - plain tuple. -""" - - -class FlaxSpeechEncoderDecoderModule(nn.Module): - config: SpeechEncoderDecoderConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - encoder_config = self.config.encoder - decoder_config = self.config.decoder - - # Copied from `modeling_hybrid_clip.py` with modifications. - from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING - - encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class - decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class - - self.encoder = encoder_module(encoder_config, dtype=self.dtype) - self.decoder = decoder_module(decoder_config, dtype=self.dtype) - - # encoder outputs might need to be projected to different dimension for decoder - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = nn.Dense( - self.decoder.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), - dtype=self.dtype, - ) - else: - self.enc_to_dec_proj = None - - def _get_feat_extract_output_lengths( - self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None - ): - """ - Computes the output length of the convolutional layers - """ - - add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - if add_adapter: - for _ in range(self.config.encoder.num_adapter_layers): - input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride) - - return input_lengths - - def _get_encoder_module(self): - return self.encoder - - def _get_projection_module(self): - return self.enc_to_dec_proj - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - inputs, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - encoder_outputs=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - freeze_feature_encoder: bool = False, - ): - if encoder_outputs is None: - encoder_outputs = self.encoder( - inputs, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - freeze_feature_encoder=freeze_feature_encoder, - ) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if self.enc_to_dec_proj is not None: - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - # compute correct encoder attention mask - if attention_mask is not None: - encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( - encoder_hidden_states.shape[1], attention_mask - ) - else: - encoder_attention_mask = None - - # flax script modeling_flax_wav2vec2.py - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=decoder_outputs.logits, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) -class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): - r""" - [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture - with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one - as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the - encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. - """ - - config_class = SpeechEncoderDecoderConfig - base_model_prefix: str = "speech_encoder_decoder" - module_class = FlaxSpeechEncoderDecoderModule - - def __init__( - self, - config: SpeechEncoderDecoderConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if not _do_init: - raise ValueError( - "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." - ) - - if config.decoder.cross_attention_hidden_size is not None: - # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer) - if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" - f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" - " `config.encoder.hidden_size`." - ) - - # make sure input & output embeddings are not tied - config.tie_word_embeddings = False - module = self.module_class(config=config, dtype=dtype, **kwargs) - - if input_shape is None: - # speech encoders almost always downsample the sequence length dimension - encoder_input_length = 1024 - decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) - input_shape = ((1, encoder_input_length), (1, decoder_input_length)) - - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - encoder_input_shape, decoder_input_shape = input_shape - - # init input DeviceArrays - inputs = jnp.zeros(encoder_input_shape, dtype="f4") - attention_mask = jnp.ones_like(inputs, dtype="i4") - decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - batch_size, sequence_length = inputs.shape - - decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape - if not decoder_batch_size == batch_size: - raise ValueError( - f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" - f" and {decoder_batch_size} for decoder." - ) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) - ) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - inputs, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - def _get_feat_extract_output_lengths( - self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None - ): - return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) - - @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def encode( - self, - inputs: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - freeze_feature_encoder: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import FlaxSpeechEncoderDecoderModel - - >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" - ... ) - - >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) - >>> encoder_outputs = model.encode(inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(inputs, dtype="i4") - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, inputs, attention_mask, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(inputs, attention_mask, **kwargs) - - outputs = self.module.apply( - {"params": params or self.params}, - inputs=jnp.array(inputs, dtype="f4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - freeze_feature_encoder=freeze_feature_encoder, - rngs=rngs, - method=_encoder_forward, - ) - - if return_dict: - outputs = FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return outputs - - @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import FlaxSpeechEncoderDecoderModel - >>> import jax.numpy as jnp - - >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" - ... ) - - >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) - >>> encoder_outputs = model.encode(inputs) - - >>> decoder_start_token_id = model.config.decoder.bos_token_id - >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - params = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBartAttention module - if past_key_values: - params["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward( - module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs - ): - projection_module = module._get_projection_module() - decoder_module = module._get_decoder_module() - - # optionally project encoder_hidden_states - if projection_module is not None: - encoder_hidden_states = projection_module(encoder_hidden_states) - - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - encoder_hidden_states=encoder_hidden_states, - **kwargs, - ) - - outputs = self.module.apply( - params, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def __call__( - self, - inputs: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - freeze_feature_encoder: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Examples: - - ```python - >>> from transformers import FlaxSpeechEncoderDecoderModel, AutoTokenizer - - >>> # load a fine-tuned wav2vec2-2-bart model - >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") - >>> # load output tokenizer - >>> tokenizer_output = AutoTokenizer.from_pretrained("facebook/bart-large") - - >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) - - >>> # use bart's special bos, pad and eos tokens - >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id - >>> model.config.pad_token_id = model.decoder.config.pad_token_id - >>> model.config.eos_token_id = model.decoder.config.eos_token_id - - >>> outputs = model.generate(inputs) - # Assert something? More interesting input? dtype correct? - ``` - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(inputs, dtype="i4") - - # prepare decoder inputs - if decoder_input_ids is None: - raise ValueError( - "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must" - " be specified as an input argument." - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - inputs=jnp.array(inputs, dtype="f4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - freeze_feature_encoder=freeze_feature_encoder, - rngs=rngs, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": decoder_position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - @classmethod - def from_encoder_decoder_pretrained( - cls, - encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - *model_args, - **kwargs, - ) -> FlaxPreTrainedModel: - r""" - Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model - checkpoints. - - Params: - encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): - Information necessary to initiate the encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import FlaxSpeechEncoderDecoderModel - - >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./wav2vec2-2-bart-large") - >>> # load fine-tuned model - >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large") - ```""" - - kwargs_encoder = { - argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove encoder, decoder kwargs from kwargs - for key in kwargs_encoder: - del kwargs["encoder_" + key] - for key in kwargs_decoder: - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("model", None) - if encoder is None: - if encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_encoder: - encoder_config, kwargs_encoder = AutoConfig.from_pretrained( - encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True - ) - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " - "from a decoder model. Cross-attention and causal mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_encoder["config"] = encoder_config - - encoder = FlaxAutoModel.from_pretrained( - encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder - ) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config, kwargs_decoder = AutoConfig.from_pretrained( - decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True - ) - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" - ) - - decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # instantiate config with corresponding kwargs - dtype = kwargs.pop("dtype", jnp.float32) - config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) - - # make sure input & output word embeddings are not tied - config.tie_word_embeddings = False - - # init model - model = cls(config, dtype=dtype) - model.params["encoder"] = encoder.params - model.params["decoder"] = decoder.params - - return model - - -__all__ = ["FlaxSpeechEncoderDecoderModel"] diff --git a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py deleted file mode 100644 index 9286fae776fd..000000000000 --- a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import torch -from torch import nn - -from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration - - -def remove_ignore_keys_(state_dict): - ignore_keys = [ - "encoder.version", - "decoder.version", - "model.encoder.version", - "model.decoder.version", - "decoder.output_projection.weight", - "_float_tensor", - "encoder.embed_positions._float_tensor", - "decoder.embed_positions._float_tensor", - ] - for k in ignore_keys: - state_dict.pop(k, None) - - -def rename_keys(s_dict): - keys = list(s_dict.keys()) - for key in keys: - if "transformer_layers" in key: - s_dict[key.replace("transformer_layers", "layers")] = s_dict.pop(key) - elif "subsample" in key: - s_dict[key.replace("subsample", "conv")] = s_dict.pop(key) - - -def make_linear_from_emb(emb): - vocab_size, emb_size = emb.weight.shape - lin_layer = nn.Linear(vocab_size, emb_size, bias=False) - lin_layer.weight.data = emb.weight.data - return lin_layer - - -def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_path): - m2m_100 = torch.load(checkpoint_path, map_location="cpu", weights_only=True) - args = m2m_100["args"] - state_dict = m2m_100["model"] - lm_head_weights = state_dict["decoder.output_projection.weight"] - - remove_ignore_keys_(state_dict) - rename_keys(state_dict) - - vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0] - - tie_embeds = args.share_decoder_input_output_embed - - conv_kernel_sizes = [int(i) for i in args.conv_kernel_sizes.split(",")] - config = Speech2TextConfig( - vocab_size=vocab_size, - max_source_positions=args.max_source_positions, - max_target_positions=args.max_target_positions, - encoder_layers=args.encoder_layers, - decoder_layers=args.decoder_layers, - encoder_attention_heads=args.encoder_attention_heads, - decoder_attention_heads=args.decoder_attention_heads, - encoder_ffn_dim=args.encoder_ffn_embed_dim, - decoder_ffn_dim=args.decoder_ffn_embed_dim, - d_model=args.encoder_embed_dim, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - activation_function="relu", - num_conv_layers=len(conv_kernel_sizes), - conv_channels=args.conv_channels, - conv_kernel_sizes=conv_kernel_sizes, - input_feat_per_channel=args.input_feat_per_channel, - input_channels=args.input_channels, - tie_word_embeddings=tie_embeds, - num_beams=5, - max_length=200, - use_cache=True, - decoder_start_token_id=2, - early_stopping=True, - ) - - model = Speech2TextForConditionalGeneration(config) - missing, unexpected = model.model.load_state_dict(state_dict, strict=False) - if len(missing) > 0 and not set(missing) <= { - "encoder.embed_positions.weights", - "decoder.embed_positions.weights", - }: - raise ValueError( - "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing," - f" but all the following weights are missing {missing}" - ) - - if tie_embeds: - model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens) - else: - model.lm_head.weight.data = lm_head_weights - - model.save_pretrained(pytorch_dump_folder_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument("--fairseq_path", type=str, help="Path to the fairseq model (.pt) file.") - parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") - args = parser.parse_args() - convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py deleted file mode 100755 index 402c005b7be7..000000000000 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ /dev/null @@ -1,1600 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow Speech2Text model.""" - -from __future__ import annotations - -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation, glu -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSharedEmbeddings, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_speech_to_text import Speech2TextConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "Speech2TextConfig" -_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr" - - -LARGE_NEGATIVE = -1e8 - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFConv1dSubsampler(keras.layers.Layer): - """ - Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation - via gated linear units (https://huggingface.co/papers/1911.08460) - """ - - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.num_layers = config.num_conv_layers - self.in_channels = config.input_feat_per_channel * config.input_channels - self.mid_channels = config.conv_channels - self.out_channels = config.d_model - self.kernel_sizes = config.conv_kernel_sizes - - self.conv_layers = [ - keras.layers.Conv1D( - filters=self.mid_channels if i < self.num_layers - 1 else self.out_channels * 2, - kernel_size=k, - strides=2, - name=f"conv_layers.{i}", - ) - for i, k in enumerate(self.kernel_sizes) - ] - - def call(self, input_features: tf.Tensor) -> tf.Tensor: - # TF Conv1D assumes Batch x Time x Channels, same as the input - hidden_states = tf.cast(input_features, tf.float32) - for i, conv in enumerate(self.conv_layers): - # equivalent to `padding=k // 2` on PT's `nn.Conv1d` - pad_len = self.kernel_sizes[i] // 2 - hidden_shapes = shape_list(hidden_states) - hidden_states = tf.concat( - ( - tf.zeros((hidden_shapes[0], pad_len, hidden_shapes[2])), - hidden_states, - tf.zeros((hidden_shapes[0], pad_len, hidden_shapes[2])), - ), - axis=1, - ) - - hidden_states = conv(hidden_states) - hidden_states = glu(hidden_states, axis=2) # GLU over the Channel dimension - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv_layers", None) is not None: - for i, layer in enumerate(self.conv_layers): - with tf.name_scope(layer.name): - layer.build([None, None, self.in_channels] if i == 0 else [None, None, self.mid_channels // 2]) - - -class TFSpeech2TextSinusoidalPositionalEmbedding(keras.layers.Layer): - """This module produces sinusoidal positional embeddings of any length.""" - - def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None, **kwargs): - super().__init__(**kwargs) - self.offset = 2 - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.embedding_weights = self._get_embedding(num_positions + self.offset, embedding_dim, padding_idx) - - @staticmethod - def _get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> tf.Tensor: - """ - Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the - description in Section 3.5 of "Attention Is All You Need". - """ - half_dim = embedding_dim // 2 - emb = tf.math.log(10000.0) / (half_dim - 1) - emb = tf.math.exp(tf.range(half_dim, dtype=tf.float32) * -emb) - emb = tf.expand_dims(tf.range(num_embeddings, dtype=tf.float32), axis=1) * tf.expand_dims(emb, axis=0) - emb = tf.reshape(tf.concat([tf.math.sin(emb), tf.math.cos(emb)], axis=1), shape=[num_embeddings, -1]) - if embedding_dim % 2 == 1: - # zero pad - emb = tf.concat([emb, tf.zeros(num_embeddings, 1)], axis=1) - if padding_idx is not None: - emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, tf.shape(emb)[1])), emb[padding_idx + 1 :, :]], axis=0) - return emb - - def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tensor: - bsz, seq_len = shape_list(input_ids) - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) - - # Matt: The PyTorch code does a lot of work to cache the embeddings, setting the cached values as a - # model attribute in the forward pass. This is extremely forbidden in TF, which wants forward calls to be - # idempotent. TF doesn't need that caching anyway, since it can just store constants during compilation, - # so we just remove all of that code. - embeddings = self._get_embedding( - self.padding_idx + 1 + seq_len + self.offset + past_key_values_length, self.embedding_dim, self.padding_idx - ) - return tf.reshape(tf.gather(embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1)) - - @staticmethod - def create_position_ids_from_input_ids( - input_ids: tf.Tensor, padding_idx: int, past_key_values_length: int | None = 0 - ) -> tf.Tensor: - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: tf.Tensor x: - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=tf.int32) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - return tf.cast(incremental_indices, dtype=tf.int64) + padding_idx - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Speech2Text -class TFSpeech2TextAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFSpeech2TextEncoderLayer(keras.layers.Layer): - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFSpeech2TextAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)` - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - training=training, - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFSpeech2TextDecoderLayer(keras.layers.Layer): - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - - self.self_attn = TFSpeech2TextAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFSpeech2TextAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training=False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - training=training, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - training=training, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFSpeech2TextPreTrainedModel(TFPreTrainedModel): - config_class = Speech2TextConfig - base_model_prefix = "model" - main_input_name = "input_features" - _keys_to_ignore_on_load_unexpected = [r"encoder.embed_positions.weights"] - - def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): - """ - Computes the output length of the convolutional layers - """ - for _ in range(self.config.num_conv_layers): - input_lengths = (input_lengths - 1) // 2 + 1 - - return input_lengths - - @property - def input_signature(self): - return { - "input_features": tf.TensorSpec( - (None, None, self.config.input_feat_per_channel * self.config.input_channels), - tf.float32, - name="input_features", - ), - "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), - "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"), - "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"), - } - - -SPEECH_TO_TEXT_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`Speech2TextConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -SPEECH_TO_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`): - Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained - by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray or a - `torch.Tensor``, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library - (`pip install soundfile`). - To prepare the arrayinto `input_features`, the [`AutoFeatureExtractor`] should be used for extracting - the fbank features, padding and conversion into a tensor of floats. - See [`~Speech2TextFeatureExtractor.__call__`] - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`Speech2TextTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - SpeechToText uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tf.FloatTensor`, *optional*): - hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - of shape `(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - decoder_inputs_embeds (`tf.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@keras_serializable -class TFSpeech2TextEncoder(keras.layers.Layer): - config_class = Speech2TextConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFSpeech2TextEncoderLayer`]. - - Args: - config: Speech2TextConfig - """ - - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.encoder_layerdrop - - embed_dim = config.d_model - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_source_positions - self.embed_scale = tf.math.sqrt(float(embed_dim)) if config.scale_embedding else 1.0 - - self.conv = TFConv1dSubsampler(config, name="conv") - - self.embed_positions = TFSpeech2TextSinusoidalPositionalEmbedding( - num_positions=config.max_source_positions, - embedding_dim=embed_dim, - padding_idx=self.padding_idx, - name="embed_positions", - ) - self.layers = [TFSpeech2TextEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): - """ - Computes the output length of the convolutional layers - """ - for _ in range(self.config.num_conv_layers): - input_lengths = (input_lengths - 1) // 2 + 1 - - return input_lengths - - def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask): - # generate creates 3D attention mask, because of the shape of input_features - # convert it to 2D if that's the case - if len(attention_mask.shape) > 2: - attention_mask = attention_mask[:, :, -1] - - subsampled_lengths = self._get_feat_extract_output_lengths(tf.math.reduce_sum(attention_mask, -1)) - bsz = shape_list(attention_mask)[0] - indices = tf.concat( - ( - tf.expand_dims(tf.range(bsz, dtype=attention_mask.dtype), -1), - tf.expand_dims(subsampled_lengths - 1, -1), - ), - axis=-1, - ) - attention_mask = tf.scatter_nd(indices=indices, updates=tf.ones(bsz), shape=[bsz, feature_vector_length]) - attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64) - return attention_mask - - @unpack_inputs - def call( - self, - input_features=None, - attention_mask=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - """ - Args: - input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`): - Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a - `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or - the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features, - padding and conversion into a tensor of floats. See [`~Speech2TextFeatureExtractor.__call__`] - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - if input_features is None: - raise ValueError("You have to specify input_features") - - inputs_embeds = self.conv(input_features) - inputs_embeds = self.embed_scale * inputs_embeds - - # subsample attention mask if necessary - if attention_mask is not None: - attention_mask = self._get_feature_vector_attention_mask(tf.shape(inputs_embeds)[1], attention_mask) - padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64) - else: - padding_mask = tf.zeros(tf.shape(inputs_embeds)[:-1], dtype=tf.int64) - - embed_pos = self.embed_positions(padding_mask) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=training) - - # check attention mask and invert - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.layers), - message=( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - training=training, - ) - - if output_attentions: - all_attentions += (attn,) - - hidden_states = self.layer_norm(hidden_states) - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build(None) - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFSpeech2TextDecoder(keras.layers.Layer): - config_class = Speech2TextConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFSpeech2TextDecoderLayer`] - - Args: - config: Speech2TextConfig - """ - - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_target_positions - self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - - self.embed_tokens = TFSharedEmbeddings(config.vocab_size, config.d_model, name="embed_tokens") - - self.embed_positions = TFSpeech2TextSinusoidalPositionalEmbedding( - num_positions=config.max_target_positions, - embedding_dim=config.d_model, - padding_idx=self.padding_idx, - name="embed_positions", - ) - - self.layers = [TFSpeech2TextDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - self.dropout = keras.layers.Dropout(config.dropout) - - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - @unpack_inputs - def call( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`Speech2TextTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - # past_key_values_length - past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - else: - inputs_embeds = inputs_embeds - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] - ) - - if attention_mask is not None: - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - # embed positions - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - ) - - if use_cache: - next_decoder_cache += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attns += (layer_cross_attn,) - - hidden_states = self.layer_norm(hidden_states) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns - else: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFSpeech2TextMainLayer(keras.layers.Layer): - config_class = Speech2TextConfig - - def __init__(self, config: Speech2TextConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.encoder = TFSpeech2TextEncoder(config, name="encoder") - self.decoder = TFSpeech2TextDecoder(config, name="decoder") - - def get_input_embeddings(self): - return self.decoder.embed_tokens - - def set_input_embeddings(self, new_embeddings): - self.decoder.embed_tokens = new_embeddings - - @unpack_inputs - def call( - self, - input_features=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs=None, - past_key_values=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - **kwargs, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_features=input_features, - attention_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not return_dict and not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - - # downsample encoder attention mask - if attention_mask is not None: - encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( - tf.shape(encoder_outputs[0])[1], attention_mask - ) - else: - encoder_attention_mask = None - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=encoder_attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare Speech2Text Model outputting raw hidden-states without any specific head on top.", - SPEECH_TO_TEXT_START_DOCSTRING, -) -class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel): - def __init__(self, config: Speech2TextConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.model = TFSpeech2TextMainLayer(config, name="model") - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSeq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_features: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> tuple | TFSeq2SeqModelOutput: - outputs = self.model( - input_features=input_features, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -@add_start_docstrings( - "The Speech2Text Model with a language modeling head. Can be used for summarization.", - SPEECH_TO_TEXT_START_DOCSTRING, -) -class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config: Speech2TextConfig): - super().__init__(config) - self.model = TFSpeech2TextMainLayer(config, name="model") - self.lm_head = keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head") - # TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate - self.supports_xla_generation = False - self.config = config - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - def resize_token_embeddings(self, new_num_tokens: int) -> tf.Variable: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - - @unpack_inputs - @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_features: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs, - ) -> tuple | TFSeq2SeqLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import Speech2TextProcessor, TFSpeech2TextForConditionalGeneration - >>> from datasets import load_dataset - - >>> model = TFSpeech2TextForConditionalGeneration.from_pretrained( - ... "facebook/s2t-small-librispeech-asr", from_pt=True - ... ) - >>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - >>> ds.set_format(type="tf") - - >>> input_features = processor( - ... ds["speech"][0], sampling_rate=16000, return_tensors="tf" - ... ).input_features # Batch size 1 - >>> generated_ids = model.generate(input_features) - - >>> transcription = processor.batch_decode(generated_ids) - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_features=input_features, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - lm_logits = self.lm_head(outputs[0]) - masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return TFSeq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - return { - "input_features": None, # needs to be passed to make Keras.layer.__call__ happy - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.config.d_model]) - - def tf_to_pt_weight_rename(self, tf_weight): - if tf_weight == "lm_head.weight": - return tf_weight, "model.decoder.embed_tokens.weight" - else: - return (tf_weight,) - - -__all__ = ["TFSpeech2TextForConditionalGeneration", "TFSpeech2TextModel", "TFSpeech2TextPreTrainedModel"] diff --git a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py deleted file mode 100644 index 612c2406a1d0..000000000000 --- a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py +++ /dev/null @@ -1,866 +0,0 @@ -# coding=utf-8 -# Copyright 2024 MBZUAI and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow SwiftFormer model.""" - -import collections.abc -from typing import Optional, Union - -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithNoAttention, - TFImageClassifierOutputWithNoAttention, -) -from ...modeling_tf_utils import TFPreTrainedModel, keras, keras_serializable, unpack_inputs -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_swiftformer import SwiftFormerConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "SwiftFormerConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs" -_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -class TFSwiftFormerPatchEmbeddingSequential(keras.layers.Layer): - """ - The sequential component of the patch embedding layer. - - Input: tensor of shape `[batch_size, in_channels, height, width]` - - Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` - """ - - def __init__(self, config: SwiftFormerConfig, **kwargs): - super().__init__(**kwargs) - self.out_chs = config.embed_dims[0] - - self.zero_padding = keras.layers.ZeroPadding2D(padding=(1, 1)) - self.conv1 = keras.layers.Conv2D(self.out_chs // 2, kernel_size=3, strides=2, name="0") - self.batch_norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="1") - self.conv2 = keras.layers.Conv2D(self.out_chs, kernel_size=3, strides=2, name="3") - self.batch_norm2 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="4") - self.config = config - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - x = self.zero_padding(x) - x = self.conv1(x) - x = self.batch_norm1(x, training=training) - x = get_tf_activation("relu")(x) - x = self.zero_padding(x) - x = self.conv2(x) - x = self.batch_norm2(x, training=training) - x = get_tf_activation("relu")(x) - return x - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "conv1", None) is not None: - with tf.name_scope(self.conv1.name): - self.conv1.build(self.config.num_channels) - if getattr(self, "batch_norm1", None) is not None: - with tf.name_scope(self.batch_norm1.name): - self.batch_norm1.build((None, None, None, self.out_chs // 2)) - if getattr(self, "conv2", None) is not None: - with tf.name_scope(self.conv2.name): - self.conv2.build((None, None, None, self.out_chs // 2)) - if getattr(self, "batch_norm2", None) is not None: - with tf.name_scope(self.batch_norm2.name): - self.batch_norm2.build((None, None, None, self.out_chs)) - self.built = True - - -class TFSwiftFormerPatchEmbedding(keras.layers.Layer): - """ - Patch Embedding Layer constructed of two 2D convolutional layers. - - Input: tensor of shape `[batch_size, in_channels, height, width]` - - Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` - """ - - def __init__(self, config: SwiftFormerConfig, **kwargs): - super().__init__(**kwargs) - self.patch_embedding = TFSwiftFormerPatchEmbeddingSequential(config, name="patch_embedding") - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - return self.patch_embedding(x, training=training) - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "patch_embedding", None) is not None: - with tf.name_scope(self.patch_embedding.name): - self.patch_embedding.build(None) - self.built = True - - -class TFSwiftFormerDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: - super().__init__(**kwargs) - raise NotImplementedError("Drop path is not implemented in TF port") - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - raise NotImplementedError("Drop path is not implemented in TF port") - - -class TFSwiftFormerEmbeddings(keras.layers.Layer): - """ - Embeddings layer consisting of a single 2D convolutional and batch normalization layer. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels, height/stride, width/stride]` - """ - - def __init__(self, config: SwiftFormerConfig, index: int, **kwargs): - super().__init__(**kwargs) - - patch_size = config.down_patch_size - stride = config.down_stride - padding = config.down_pad - embed_dims = config.embed_dims - - self.in_chans = embed_dims[index] - self.embed_dim = embed_dims[index + 1] - - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) - padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) - - self.pad = keras.layers.ZeroPadding2D(padding=padding) - self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size=patch_size, strides=stride, name="proj") - self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - x = self.pad(x) - x = self.proj(x) - x = self.norm(x, training=training) - return x - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build(self.in_chans) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build((None, None, None, self.embed_dim)) - self.built = True - - -class TFSwiftFormerConvEncoder(keras.layers.Layer): - """ - `SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels, height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): - super().__init__(**kwargs) - hidden_dim = int(config.mlp_ratio * dim) - - self.dim = dim - self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) - self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") - self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") - self.point_wise_conv1 = keras.layers.Conv2D(hidden_dim, kernel_size=1, name="point_wise_conv1") - self.act = get_tf_activation("gelu") - self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") - self.drop_path = keras.layers.Dropout(name="drop_path", rate=config.drop_conv_encoder_rate) - self.hidden_dim = int(config.mlp_ratio * self.dim) - - def build(self, input_shape=None): - if self.built: - return - self.layer_scale = self.add_weight( - name="layer_scale", - shape=self.dim, - initializer="ones", - trainable=True, - ) - - if getattr(self, "depth_wise_conv", None) is not None: - with tf.name_scope(self.depth_wise_conv.name): - self.depth_wise_conv.build(self.dim) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build((None, None, None, self.dim)) - if getattr(self, "point_wise_conv1", None) is not None: - with tf.name_scope(self.point_wise_conv1.name): - self.point_wise_conv1.build(self.dim) - if getattr(self, "point_wise_conv2", None) is not None: - with tf.name_scope(self.point_wise_conv2.name): - self.point_wise_conv2.build(self.hidden_dim) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - self.built = True - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - input = x - x = self.pad(x) - x = self.depth_wise_conv(x) - x = self.norm(x, training=training) - x = self.point_wise_conv1(x) - x = self.act(x) - x = self.point_wise_conv2(x) - x = input + self.drop_path(self.layer_scale * x) - return x - - -class TFSwiftFormerMlp(keras.layers.Layer): - """ - MLP layer with 1*1 convolutions. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels, height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, in_features: int, **kwargs): - super().__init__(**kwargs) - - hidden_features = int(in_features * config.mlp_ratio) - self.norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm1") - self.fc1 = keras.layers.Conv2D(hidden_features, 1, name="fc1") - act_layer = get_tf_activation(config.hidden_act) - self.act = act_layer - self.fc2 = keras.layers.Conv2D(in_features, 1, name="fc2") - self.drop = keras.layers.Dropout(rate=config.drop_mlp_rate) - self.hidden_features = hidden_features - self.in_features = in_features - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - x = self.norm1(x, training=training) - x = self.fc1(x) - x = self.act(x) - x = self.drop(x, training=training) - x = self.fc2(x) - x = self.drop(x, training=training) - return x - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "norm1", None) is not None: - with tf.name_scope(self.norm1.name): - self.norm1.build((None, None, None, self.in_features)) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build((None, None, None, self.in_features)) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build((None, None, None, self.hidden_features)) - self.built = True - - -class TFSwiftFormerEfficientAdditiveAttention(keras.layers.Layer): - """ - Efficient Additive Attention module for SwiftFormer. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels, height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, dim: int = 512, **kwargs): - super().__init__(**kwargs) - - self.dim = dim - - self.to_query = keras.layers.Dense(dim, name="to_query") - self.to_key = keras.layers.Dense(dim, name="to_key") - - self.scale_factor = dim**-0.5 - self.proj = keras.layers.Dense(dim, name="proj") - self.final = keras.layers.Dense(dim, name="final") - - def build(self, input_shape=None): - if self.built: - return - self.w_g = self.add_weight( - name="w_g", - shape=(self.dim, 1), - initializer=keras.initializers.RandomNormal(mean=0, stddev=1), - trainable=True, - ) - - if getattr(self, "to_query", None) is not None: - with tf.name_scope(self.to_query.name): - self.to_query.build(self.dim) - if getattr(self, "to_key", None) is not None: - with tf.name_scope(self.to_key.name): - self.to_key.build(self.dim) - if getattr(self, "proj", None) is not None: - with tf.name_scope(self.proj.name): - self.proj.build(self.dim) - if getattr(self, "final", None) is not None: - with tf.name_scope(self.final.name): - self.final.build(self.dim) - self.built = True - - def call(self, x: tf.Tensor) -> tf.Tensor: - query = self.to_query(x) - key = self.to_key(x) - - query = tf.math.l2_normalize(query, dim=-1) - key = tf.math.l2_normalize(key, dim=-1) - - query_weight = query @ self.w_g - scaled_query_weight = query_weight * self.scale_factor - scaled_query_weight = tf.nn.softmax(scaled_query_weight, axis=-1) - - global_queries = tf.math.reduce_sum(scaled_query_weight * query, axis=1) - global_queries = tf.tile(tf.expand_dims(global_queries, 1), (1, key.shape[1], 1)) - - out = self.proj(global_queries * key) + query - out = self.final(out) - - return out - - -class TFSwiftFormerLocalRepresentation(keras.layers.Layer): - """ - Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels, height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): - super().__init__(**kwargs) - - self.dim = dim - - self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) - self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") - self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") - self.point_wise_conv1 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv1") - self.act = get_tf_activation("gelu") - self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") - self.drop_path = keras.layers.Identity(name="drop_path") - - def build(self, input_shape=None): - if self.built: - return - self.layer_scale = self.add_weight( - name="layer_scale", - shape=(self.dim), - initializer="ones", - trainable=True, - ) - if getattr(self, "depth_wise_conv", None) is not None: - with tf.name_scope(self.depth_wise_conv.name): - self.depth_wise_conv.build((None, None, None, self.dim)) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build((None, None, None, self.dim)) - if getattr(self, "point_wise_conv1", None) is not None: - with tf.name_scope(self.point_wise_conv1.name): - self.point_wise_conv1.build(self.dim) - if getattr(self, "point_wise_conv2", None) is not None: - with tf.name_scope(self.point_wise_conv2.name): - self.point_wise_conv2.build(self.dim) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - self.built = True - - def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: - input = x - x = self.pad(x) - x = self.depth_wise_conv(x) - x = self.norm(x, training=training) - x = self.point_wise_conv1(x) - x = self.act(x) - x = self.point_wise_conv2(x) - x = input + self.drop_path(self.layer_scale * x, training=training) - return x - - -class TFSwiftFormerEncoderBlock(keras.layers.Layer): - """ - SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2) - SwiftFormerEfficientAdditiveAttention, and (3) MLP block. - - Input: tensor of shape `[batch_size, channels, height, width]` - - Output: tensor of shape `[batch_size, channels,height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): - super().__init__(**kwargs) - - layer_scale_init_value = config.layer_scale_init_value - use_layer_scale = config.use_layer_scale - - self.local_representation = TFSwiftFormerLocalRepresentation(config, dim=dim, name="local_representation") - self.attn = TFSwiftFormerEfficientAdditiveAttention(config, dim=dim, name="attn") - self.linear = TFSwiftFormerMlp(config, in_features=dim, name="linear") - self.drop_path = TFSwiftFormerDropPath(config) if drop_path > 0.0 else keras.layers.Identity() - self.use_layer_scale = use_layer_scale - if use_layer_scale: - self.dim = dim - self.layer_scale_init_value = layer_scale_init_value - - def build(self, input_shape=None): - if self.built: - return - self.layer_scale_1 = self.add_weight( - name="layer_scale_1", - shape=self.dim, - initializer=keras.initializers.constant(self.layer_scale_init_value), - trainable=True, - ) - self.layer_scale_2 = self.add_weight( - name="layer_scale_2", - shape=self.dim, - initializer=keras.initializers.constant(self.layer_scale_init_value), - trainable=True, - ) - - if getattr(self, "local_representation", None) is not None: - with tf.name_scope(self.local_representation.name): - self.local_representation.build(None) - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "linear", None) is not None: - with tf.name_scope(self.linear.name): - self.linear.build(None) - self.built = True - - def call(self, x: tf.Tensor, training: bool = False): - x = self.local_representation(x, training=training) - batch_size, height, width, channels = x.shape - - res = tf.reshape(x, [-1, height * width, channels]) - res = self.attn(res) - res = tf.reshape(res, [-1, height, width, channels]) - if self.use_layer_scale: - x = x + self.drop_path(self.layer_scale_1 * res, training=training) - x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training) - else: - x = x + self.drop_path(res, training=training) - x = x + self.drop_path(self.linear(x), training=training) - return x - - -class TFSwiftFormerStage(keras.layers.Layer): - """ - A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final - `SwiftFormerEncoderBlock`. - - Input: tensor in shape `[batch_size, channels, height, width]` - - Output: tensor in shape `[batch_size, channels, height, width]` - """ - - def __init__(self, config: SwiftFormerConfig, index: int, **kwargs) -> None: - super().__init__(**kwargs) - - layer_depths = config.depths - dim = config.embed_dims[index] - depth = layer_depths[index] - - self.blocks = [] - for block_idx in range(depth): - block_dpr = config.drop_path_rate * (block_idx + sum(layer_depths[:index])) / (sum(layer_depths) - 1) - - if depth - block_idx <= 1: - self.blocks.append( - TFSwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr, name=f"blocks_._{block_idx}") - ) - else: - self.blocks.append(TFSwiftFormerConvEncoder(config, dim=dim, name=f"blocks_._{block_idx}")) - - def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor: - for i, block in enumerate(self.blocks): - input = block(input, training=training) - return input - - def build(self, input_shape=None): - for layer in self.blocks: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFSwiftFormerEncoder(keras.layers.Layer): - def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: - super().__init__(**kwargs) - self.config = config - - embed_dims = config.embed_dims - downsamples = config.downsamples - layer_depths = config.depths - - # Transformer model - self.network = [] - name_i = 0 - for i in range(len(layer_depths)): - stage = TFSwiftFormerStage(config, index=i, name=f"network_._{name_i}") - self.network.append(stage) - name_i += 1 - if i >= len(layer_depths) - 1: - break - if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: - # downsampling between two stages - self.network.append(TFSwiftFormerEmbeddings(config, index=i, name=f"network_._{name_i}")) - name_i += 1 - - self.gradient_checkpointing = False - - def call( - self, - hidden_states: tf.Tensor, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple, TFBaseModelOutputWithNoAttention]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - all_hidden_states = (hidden_states,) if output_hidden_states else None - - for i, block in enumerate(self.network): - hidden_states = block(hidden_states, training=training) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) - if all_hidden_states: - all_hidden_states = tuple(tf.transpose(s, perm=[0, 3, 1, 2]) for s in all_hidden_states) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - - return TFBaseModelOutputWithNoAttention( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - ) - - def build(self, input_shape=None): - for layer in self.network: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFSwiftFormerPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SwiftFormerConfig - base_model_prefix = "swiftformer" - main_input_name = "pixel_values" - - -TFSWIFTFORMER_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TF 2.0 models accepts two formats as inputs: - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional arguments. - This second option is useful when using [`keras.Model.fit`] method which currently requires having all the - tensors in the first argument of the model call function: `model(inputs)`. - If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the - first positional argument : - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - - - Parameters: - config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TFSWIFTFORMER_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to run the model in training mode. -""" - - -@keras_serializable -class TFSwiftFormerMainLayer(keras.layers.Layer): - config_class = SwiftFormerConfig - - def __init__(self, config: SwiftFormerConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - - self.patch_embed = TFSwiftFormerPatchEmbedding(config, name="patch_embed") - self.encoder = TFSwiftFormerEncoder(config, name="encoder") - - @unpack_inputs - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple, TFBaseModelOutputWithNoAttention]: - r""" """ - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TF 2.0 image layers can't use NCHW format when running on CPU. - # We transpose to NHWC format and then transpose back after the full forward pass. - # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) - pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - embedding_output = self.patch_embed(pixel_values, training=training) - encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return tuple(v for v in encoder_outputs if v is not None) - - return TFBaseModelOutputWithNoAttention( - last_hidden_state=encoder_outputs.last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "patch_embed", None) is not None: - with tf.name_scope(self.patch_embed.name): - self.patch_embed.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - self.built = True - - -@add_start_docstrings( - "The bare TFSwiftFormer Model transformer outputting raw hidden-states without any specific head on top.", - TFSWIFTFORMER_START_DOCSTRING, -) -class TFSwiftFormerModel(TFSwiftFormerPreTrainedModel): - def __init__(self, config: SwiftFormerConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[TFBaseModelOutputWithNoAttention, tuple[tf.Tensor]]: - outputs = self.swiftformer( - pixel_values=pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "swiftformer", None) is not None: - with tf.name_scope(self.swiftformer.name): - self.swiftformer.build(None) - self.built = True - - -@add_start_docstrings( - """ - TFSwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet). - """, - TFSWIFTFORMER_START_DOCSTRING, -) -class TFSwiftFormerForImageClassification(TFSwiftFormerPreTrainedModel): - def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.num_labels = config.num_labels - self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer") - - # Classifier head - self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") - self.head = ( - keras.layers.Dense(self.num_labels, name="head") - if self.num_labels > 0 - else keras.layers.Identity(name="head") - ) - self.dist_head = ( - keras.layers.Dense(self.num_labels, name="dist_head") - if self.num_labels > 0 - else keras.layers.Identity(name="dist_head") - ) - - def hf_compute_loss(self, labels, logits): - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = keras.losses.MSE - if self.num_labels == 1: - loss = loss_fct(labels.squeeze(), logits.squeeze()) - else: - loss = loss_fct(labels, logits) - elif self.config.problem_type == "single_label_classification": - loss_fct = keras.losses.SparseCategoricalCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.NONE - ) - loss = loss_fct(labels, logits) - elif self.config.problem_type == "multi_label_classification": - loss_fct = keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=keras.losses.Reduction.NONE, - ) - loss = loss_fct(labels, logits) - else: - loss = None - - return loss - - @unpack_inputs - @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING) - def call( - self, - pixel_values: Optional[tf.Tensor] = None, - labels: Optional[tf.Tensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: bool = False, - ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # run base model - outputs = self.swiftformer( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs.last_hidden_state if return_dict else outputs[0] - sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1]) - - # run classification head - sequence_output = self.norm(sequence_output, training=training) - sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2]) - _, num_channels, height, width = sequence_output.shape - sequence_output = tf.reshape(sequence_output, [-1, num_channels, height * width]) - sequence_output = tf.reduce_mean(sequence_output, axis=-1) - cls_out = self.head(sequence_output) - distillation_out = self.dist_head(sequence_output) - logits = (cls_out + distillation_out) / 2 - - # calculate loss - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFImageClassifierOutputWithNoAttention( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - if getattr(self, "swiftformer", None) is not None: - with tf.name_scope(self.swiftformer.name): - self.swiftformer.build(None) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build((None, None, None, self.config.embed_dims[-1])) - if getattr(self, "head", None) is not None: - with tf.name_scope(self.head.name): - self.head.build(self.config.embed_dims[-1]) - if getattr(self, "dist_head", None) is not None: - with tf.name_scope(self.dist_head.name): - self.dist_head.build(self.config.embed_dims[-1]) - self.built = True - - -__all__ = ["TFSwiftFormerForImageClassification", "TFSwiftFormerModel", "TFSwiftFormerPreTrainedModel"] diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py deleted file mode 100644 index 7fa54e958046..000000000000 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ /dev/null @@ -1,1639 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 Swin Transformer model.""" - -from __future__ import annotations - -import collections.abc -import math -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable - -import tensorflow as tf - -from ...activations_tf import ACT2FN -from ...modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_swin import SwinConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "SwinConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224" -_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224" -_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" - - -# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow -# implementations of PyTorch functionalities in the timm library. - - -@dataclass -class TFSwinEncoderOutput(ModelOutput): - """ - Swin encoder's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, hidden_size, height, width)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to - include the spatial dimensions. - """ - - last_hidden_state: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFSwinModelOutput(ModelOutput): - """ - Swin model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): - Average pooling of the last layer hidden-state. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, hidden_size, height, width)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to - include the spatial dimensions. - """ - - last_hidden_state: tf.Tensor | None = None - pooler_output: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFSwinMaskedImageModelingOutput(ModelOutput): - """ - Swin masked image model outputs. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): - Masked image modeling (MLM) loss. - reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed pixel values. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, hidden_size, height, width)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to - include the spatial dimensions. - """ - - loss: tf.Tensor | None = None - reconstruction: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None - - @property - def logits(self): - warnings.warn( - "logits attribute is deprecated and will be removed in version 5 of Transformers." - " Please use the reconstruction attribute to retrieve the final output instead.", - FutureWarning, - ) - return self.reconstruction - - -@dataclass -class TFSwinImageClassifierOutput(ModelOutput): - """ - Swin outputs for image classification. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape - `(batch_size, hidden_size, height, width)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to - include the spatial dimensions. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - reshaped_hidden_states: tuple[tf.Tensor, ...] | None = None - - -def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: - """ - Partitions the given input into windows. - """ - batch_size, height, width, num_channels = shape_list(input_feature) - input_feature = tf.reshape( - input_feature, - (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels), - ) - windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5)) - windows = tf.reshape(windows, (-1, window_size, window_size, num_channels)) - return windows - - -def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor: - """ - Merges windows to produce higher resolution features. - """ - x = tf.shape(windows)[0] - y = tf.cast(height * width / (window_size * window_size), tf.int32) - batch_size = tf.math.floordiv(x, y) - windows = tf.reshape( - windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1) - ) - windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5)) - windows = tf.reshape(windows, (batch_size, height, width, -1)) - return windows - - -def drop_path( - input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True -) -> tf.Tensor: - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - input_shape = shape_list(input) - ndim = len(input_shape) - shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = tf.random.uniform(shape) - random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0) - if keep_prob > 0.0 and scale_by_keep: - random_tensor /= keep_prob - return input * random_tensor - - -class TFSwinEmbeddings(keras.layers.Layer): - """ - Construct the patch and position embeddings. Optionally, also the mask token. - """ - - def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None: - super().__init__(**kwargs) - self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings") - self.num_patches = self.patch_embeddings.num_patches - self.patch_grid = self.patch_embeddings.grid_size - self.embed_dim = config.embed_dim - self.use_mask_token = use_mask_token - self.use_absolute_embeddings = config.use_absolute_embeddings - - self.norm = keras.layers.LayerNormalization(name="norm", epsilon=1e-5) - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - self.config = config - - def build(self, input_shape: tf.TensorShape) -> None: - if self.use_mask_token: - self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token") - else: - self.mask_token = None - - if self.use_absolute_embeddings: - self.position_embeddings = self.add_weight( - (1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings" - ) - else: - self.position_embeddings = None - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build([None, None, self.config.embed_dim]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - def call( - self, pixel_values: tf.Tensor, bool_masked_pos: bool | None = None, training: bool = False - ) -> tuple[tf.Tensor, tuple[int, int]]: - embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training) - embeddings = self.norm(embeddings, training=training) - batch_size, seq_len, _ = shape_list(embeddings) - - if bool_masked_pos is not None: - mask_tokens = tf.repeat(self.mask_token, batch_size, 0) - mask_tokens = tf.repeat(mask_tokens, seq_len, 1) - # replace the masked visual tokens by mask_tokens - mask = tf.expand_dims(bool_masked_pos, -1) - mask = tf.cast(mask, mask_tokens.dtype) - - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings, training=training) - - return embeddings, output_dimensions - - -class TFSwinPatchEmbeddings(keras.layers.Layer): - """ - Image to Patch Embedding. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.embed_dim - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) - - self.projection = keras.layers.Conv2D( - filters=hidden_size, - kernel_size=self.patch_size, - strides=self.patch_size, - padding="valid", - name="projection", - ) - - def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor: - if width % self.patch_size[1] != 0: - pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1])) - pixel_values = tf.pad(pixel_values, pad_values) - if height % self.patch_size[0] != 0: - pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0)) - pixel_values = tf.pad(pixel_values, pad_values) - return pixel_values - - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tuple[tf.Tensor, tuple[int, int]]: - _, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - # pad the input to be divisible by self.patch_size, if needed - pixel_values = self.maybe_pad(pixel_values, height, width) - - # B,C,H,W -> B,H,W,C - pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1)) - - embeddings = self.projection(pixel_values, training=training) - - # B,H,W,C -> B,C,H,W - embeddings = tf.transpose(embeddings, (0, 3, 1, 2)) - - batch_size, channels, height, width = shape_list(embeddings) - output_dimensions = (height, width) - - embeddings = tf.reshape(embeddings, (batch_size, channels, -1)) - embeddings = tf.transpose(embeddings, (0, 2, 1)) - return embeddings, output_dimensions - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -class TFSwinPatchMerging(keras.layers.Layer): - """ - Patch Merging Layer. - - Args: - input_resolution (`tuple[int]`): - Resolution of input feature. - dim (`int`): - Number of input channels. - norm_layer (`keras.layer.Layer`, *optional*, defaults to `keras.layers.LayerNormalization`): - Normalization layer class. - """ - - def __init__( - self, input_resolution: tuple[int, int], dim: int, norm_layer: Callable | None = None, **kwargs - ) -> None: - super().__init__(**kwargs) - self.input_resolution = input_resolution - self.dim = dim - self.reduction = keras.layers.Dense(2 * dim, use_bias=False, name="reduction") - if norm_layer is None: - # Use same default epsilon as PyTorch - self.norm = keras.layers.LayerNormalization(epsilon=1e-5, name="norm") - else: - self.norm = norm_layer(name="norm") - - def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor: - should_pad = (height % 2 == 1) or (width % 2 == 1) - if should_pad: - pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0)) - input_feature = tf.pad(input_feature, pad_values) - - return input_feature - - def call(self, input_feature: tf.Tensor, input_dimensions: tuple[int, int], training: bool = False) -> tf.Tensor: - height, width = input_dimensions - # `dim` is height * width - batch_size, _, num_channels = shape_list(input_feature) - - input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels)) - # pad input to be divisible by width and height, if needed - input_feature = self.maybe_pad(input_feature, height, width) - # [batch_size, height/2, width/2, num_channels] - input_feature_0 = input_feature[:, 0::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_1 = input_feature[:, 1::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_2 = input_feature[:, 0::2, 1::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_3 = input_feature[:, 1::2, 1::2, :] - # batch_size height/2 width/2 4*num_channels - input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) - input_feature = tf.reshape( - input_feature, (batch_size, -1, 4 * num_channels) - ) # batch_size height/2*width/2 4*C - - input_feature = self.norm(input_feature, training=training) - input_feature = self.reduction(input_feature, training=training) - - return input_feature - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "reduction", None) is not None: - with tf.name_scope(self.reduction.name): - self.reduction.build([None, None, 4 * self.dim]) - if getattr(self, "norm", None) is not None: - with tf.name_scope(self.norm.name): - self.norm.build([None, None, 4 * self.dim]) - - -class TFSwinDropPath(keras.layers.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: float | None = None, scale_by_keep: bool = True, **kwargs) -> None: - super().__init__(**kwargs) - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - - def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor: - return drop_path(input, self.drop_prob, training, self.scale_by_keep) - - -class TFSwinSelfAttention(keras.layers.Layer): - def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None: - super().__init__(**kwargs) - if dim % num_heads != 0: - raise ValueError( - f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" - ) - - self.num_attention_heads = num_heads - self.attention_head_size = int(dim / num_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - window_size = config.window_size - self.window_size = ( - window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) - ) - - self.query = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=config.qkv_bias, - name="query", - ) - self.key = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=config.qkv_bias, - name="key", - ) - self.value = keras.layers.Dense( - self.all_head_size, - kernel_initializer=get_initializer(config.initializer_range), - use_bias=config.qkv_bias, - name="value", - ) - - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) - - def build(self, input_shape: tf.TensorShape) -> None: - self.relative_position_bias_table = self.add_weight( - shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads), - initializer="zeros", - name="relative_position_bias_table", - ) - self.relative_position_index = self.add_weight( - shape=(self.window_size[0] ** 2, self.window_size[1] ** 2), - trainable=False, - dtype=tf.int32, - name="relative_position_index", - ) - - # get pair-wise relative position index for each token inside the window - coords_h = tf.range(self.window_size[0]) - coords_w = tf.range(self.window_size[1]) - coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij")) - coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1)) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = tf.transpose(relative_coords, (1, 2, 0)) - - stack_0, stack_1 = tf.unstack(relative_coords, axis=2) - stack_0 += self.window_size[0] - 1 - stack_0 *= 2 * self.window_size[1] - 1 - stack_1 += self.window_size[1] - 1 - relative_coords = tf.stack([stack_0, stack_1], axis=2) - - self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32)) - - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.all_head_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.all_head_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.all_head_size]) - - def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tuple[tf.Tensor, ...]: - batch_size, dim, _ = shape_list(hidden_states) - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2))) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - relative_position_bias = tf.gather( - self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,)) - ) - relative_position_bias = tf.reshape( - relative_position_bias, - (self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1), - ) - - relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1)) - attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in SwinModel call() function) - mask_shape = shape_list(attention_mask)[0] - attention_scores = tf.reshape( - attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim) - ) - attention_mask = tf.expand_dims(attention_mask, 1) - attention_mask = tf.expand_dims(attention_mask, 0) - attention_scores = attention_scores + attention_mask - attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim)) - - # Normalize the attention scores to probabilities. - attention_probs = tf.nn.softmax(attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = tf.matmul(attention_probs, value_layer) - context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) - new_context_layer_shape = shape_list(context_layer)[:-2] + [ - self.all_head_size, - ] - context_layer = tf.reshape(context_layer, new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class TFSwinSelfOutput(keras.layers.Layer): - def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(dim, name="dense") - self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout") - self.dim = dim - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.dim]) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - - -class TFSwinAttention(keras.layers.Layer): - def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None: - super().__init__(**kwargs) - self.self = TFSwinSelfAttention(config, dim, num_heads, name="self") - self.self_output = TFSwinSelfOutput(config, dim, name="output") - self.pruned_heads = set() - - def prune_heads(self, heads): - """ - Prunes heads of the model. See base class PreTrainedModel heads: dict of {layer_num: list of heads to prune in - this layer} - """ - raise NotImplementedError - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tf.Tensor: - self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training) - attention_output = self.self_output(self_outputs[0], hidden_states, training=training) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self", None) is not None: - with tf.name_scope(self.self.name): - self.self.build(None) - if getattr(self, "self_output", None) is not None: - with tf.name_scope(self.self_output.name): - self.self_output.build(None) - - -class TFSwinIntermediate(keras.layers.Layer): - def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(int(config.mlp_ratio * dim), name="dense") - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - self.dim = dim - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.dim]) - - -class TFSwinOutput(keras.layers.Layer): - def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: - super().__init__(**kwargs) - self.dense = keras.layers.Dense(dim, name="dense") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, "dropout") - self.config = config - self.dim = dim - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, int(self.config.mlp_ratio * self.dim)]) - - -class TFSwinLayer(keras.layers.Layer): - def __init__( - self, - config, - dim, - input_resolution: tuple[int, int], - num_heads: int, - drop_path_rate: float = 0.0, - shift_size: int = 0, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.chunk_size_feed_forward = config.chunk_size_feed_forward - min_res = tf.reduce_min(input_resolution) - self.window_size = min_res if min_res <= config.window_size else config.window_size - self.shift_size = 0 if min_res <= self.window_size else shift_size - self.input_resolution = input_resolution - - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.attention = TFSwinAttention(config, dim, num_heads, name="attention") - self.drop_path = ( - TFSwinDropPath(drop_path_rate, name="drop_path") - if drop_path_rate > 0.0 - else keras.layers.Activation("linear", name="drop_path") - ) - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - self.intermediate = TFSwinIntermediate(config, dim, name="intermediate") - self.swin_output = TFSwinOutput(config, dim, name="output") - self.dim = dim - - def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> tf.Tensor | None: - img_mask = tf.zeros((height, width)) - height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) - width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) - - # calculate attention mask for SW-MSA - if shift_size > 0: - count = 0 - for height_slice in height_slices: - for width_slice in width_slices: - height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1) - width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1) - indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2)) - if len(indices) >= 1: - updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count - img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates) - count += 1 - - img_mask = tf.expand_dims(img_mask, -1) - img_mask = tf.expand_dims(img_mask, 0) - - mask_windows = window_partition(img_mask, window_size) - mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size)) - attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2) - attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask) - attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask) - return attn_mask - - def maybe_pad( - self, hidden_states: tf.Tensor, window_size: int, height: int, width: int - ) -> tuple[tf.Tensor, tf.Tensor]: - pad_right = (window_size - width % window_size) % window_size - pad_bottom = (window_size - height % window_size) % window_size - pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]] - hidden_states = tf.pad(hidden_states, pad_values) - pad_values = tf.reshape(pad_values, (-1,)) - return hidden_states, pad_values - - def call( - self, - hidden_states: tf.Tensor, - input_dimensions: tuple[int, int], - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - training: bool = False, - ) -> tf.Tensor: - # if window size is larger than input resolution, we don't partition windows - min_res = tf.reduce_min(input_dimensions) - shift_size = 0 if min_res <= self.window_size else self.shift_size - window_size = min_res if min_res <= self.window_size else self.window_size - - height, width = input_dimensions - batch_size, _, channels = shape_list(hidden_states) - shortcut = hidden_states - - hidden_states = self.layernorm_before(hidden_states, training=training) - hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels)) - # pad hidden_states to multiples of window size - hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width) - - _, height_pad, width_pad, _ = shape_list(hidden_states) - # cyclic shift - if shift_size > 0: - shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2)) - else: - shifted_hidden_states = hidden_states - - # partition windows - hidden_states_windows = window_partition(shifted_hidden_states, window_size) - hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels)) - attn_mask = self.get_attn_mask( - height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size - ) - - attention_outputs = self.attention( - hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training - ) - - attention_output = attention_outputs[0] - - attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels)) - shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad) - - # reverse cyclic shift - if shift_size > 0: - attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2)) - else: - attention_windows = shifted_windows - - was_padded = pad_values[3] > 0 or pad_values[5] > 0 - if was_padded: - attention_windows = attention_windows[:, :height, :width, :] - - attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels)) - - hidden_states = shortcut + self.drop_path(attention_windows, training=training) - - layer_output = self.layernorm_after(hidden_states, training=training) - layer_output = self.intermediate(layer_output) - layer_output = hidden_states + self.swin_output(layer_output, training=training) - - layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) - return layer_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.dim]) - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "drop_path", None) is not None: - with tf.name_scope(self.drop_path.name): - self.drop_path.build(None) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.dim]) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "swin_output", None) is not None: - with tf.name_scope(self.swin_output.name): - self.swin_output.build(None) - - -class TFSwinStage(keras.layers.Layer): - def __init__( - self, - config: SwinConfig, - dim: int, - input_resolution: tuple[int, int], - depth: int, - num_heads: int, - drop_path: list[float], - downsample: Callable | None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.config = config - self.dim = dim - self.blocks = [ - TFSwinLayer( - config=config, - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - shift_size=0 if (i % 2 == 0) else config.window_size // 2, - drop_path_rate=drop_path[i], - name=f"blocks.{i}", - ) - for i in range(depth) - ] - - # patch merging layer - if downsample is not None: - self.downsample = downsample( - input_resolution, - dim=dim, - norm_layer=partial(keras.layers.LayerNormalization, epsilon=1e-5), - name="downsample", - ) - else: - self.downsample = None - - self.pointing = False - - def call( - self, - hidden_states: tf.Tensor, - input_dimensions: tuple[int, int], - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor, ...]: - height, width = input_dimensions - for i, layer_module in enumerate(self.blocks): - layer_head_mask = head_mask[i] if head_mask is not None else None - - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training - ) - - hidden_states = layer_outputs[0] - - if self.downsample is not None: - height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 - output_dimensions = (height, width, height_downsampled, width_downsampled) - hidden_states = self.downsample(layer_outputs[0], input_dimensions, training=training) - else: - output_dimensions = (height, width, height, width) - - stage_outputs = (hidden_states, output_dimensions) - - if output_attentions: - stage_outputs += layer_outputs[1:] - return stage_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "downsample", None) is not None: - with tf.name_scope(self.downsample.name): - self.downsample.build(None) - if getattr(self, "blocks", None) is not None: - for layer in self.blocks: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFSwinEncoder(keras.layers.Layer): - def __init__(self, config: SwinConfig, grid_size: tuple[int, int], **kwargs): - super().__init__(**kwargs) - self.num_layers = len(config.depths) - self.config = config - dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy()) - self.layers = [ - TFSwinStage( - config=config, - dim=int(config.embed_dim * 2**i_layer), - input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), - depth=config.depths[i_layer], - num_heads=config.num_heads[i_layer], - drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], - downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None, - name=f"layers.{i_layer}", - ) - for i_layer in range(self.num_layers) - ] - - self.gradient_checkpointing = False - - def call( - self, - hidden_states: tf.Tensor, - input_dimensions: tuple[int, int], - head_mask: tf.Tensor | None = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - training: bool = False, - ) -> tuple[tf.Tensor, ...] | TFSwinEncoderOutput: - all_input_dimensions = () - all_hidden_states = () if output_hidden_states else None - all_reshaped_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if output_hidden_states: - batch_size, _, hidden_size = shape_list(hidden_states) - # rearrange b (h w) c -> b c h w - reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) - reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) - all_hidden_states += (hidden_states,) - all_reshaped_hidden_states += (reshaped_hidden_state,) - - for i, layer_module in enumerate(self.layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training - ) - - hidden_states = layer_outputs[0] - output_dimensions = layer_outputs[1] - - input_dimensions = (output_dimensions[-2], output_dimensions[-1]) - all_input_dimensions += (input_dimensions,) - - if output_hidden_states: - batch_size, _, hidden_size = shape_list(hidden_states) - # rearrange b (h w) c -> b c h w - reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) - reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) - all_hidden_states += (hidden_states,) - all_reshaped_hidden_states += (reshaped_hidden_state,) - - if output_attentions: - all_self_attentions += layer_outputs[2:] - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - - return TFSwinEncoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - reshaped_hidden_states=all_reshaped_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFSwinPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SwinConfig - base_model_prefix = "swin" - main_input_name = "pixel_values" - - -SWIN_START_DOCSTRING = r""" - This model is a Tensorflow - [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a - regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`SwinConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -SWIN_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def normalize_data_format(value: str) -> str: - """ - From tensorflow addons - https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py#L71 - """ - if value is None: - value = keras.backend.image_data_format() - data_format = value.lower() - if data_format not in {"channels_first", "channels_last"}: - raise ValueError( - 'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value) - ) - return data_format - - -class AdaptiveAveragePooling1D(keras.layers.Layer): - """ - Args: - Average 1D Pooling with adaptive kernel size. - output_size: An integer or tuple/list of a single integer, specifying pooled_features. - The new size of output channels. - data_format: A string, - one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds - to inputs with shape `(batch, channels, steps)`. - Input shape: - - If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`. - - If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`. - Output shape: - - If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`. - - If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`. - - Adapted from [tensorflow-addon's adaptive pooling.py]( - https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120 - ) - """ - - def __init__( - self, - output_size: int | Iterable[int], - reduce_function: Callable = tf.reduce_mean, - data_format: str | None = None, - **kwargs, - ) -> None: - self.data_format = normalize_data_format(data_format) - self.reduce_function = reduce_function - self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size) - super().__init__(**kwargs) - - def call(self, inputs: tf.Tensor, *args) -> None: - bins = self.output_size[0] - if self.data_format == "channels_last": - splits = tf.split(inputs, bins, axis=1) - splits = tf.stack(splits, axis=1) - out_vect = self.reduce_function(splits, axis=2) - else: - splits = tf.split(inputs, bins, axis=2) - splits = tf.stack(splits, axis=2) - out_vect = self.reduce_function(splits, axis=3) - return out_vect - - def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape: - input_shape = tf.TensorShape(input_shape).as_list() - if self.data_format == "channels_last": - shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]]) - else: - shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]]) - return shape - - def get_config(self) -> dict[str, Any]: - config = { - "output_size": self.output_size, - "data_format": self.data_format, - } - base_config = super().get_config() - return {**base_config, **config} - - -@keras_serializable -class TFSwinMainLayer(keras.layers.Layer): - config_class = SwinConfig - - def __init__( - self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs - ) -> None: - super().__init__(**kwargs) - self.config = config - self.num_layers = len(config.depths) - self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) - - self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings") - self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder") - - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None - - def get_input_embeddings(self) -> TFSwinPatchEmbeddings: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune: dict[int, list]): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def get_head_mask(self, head_mask: Any | None) -> list: - if head_mask is not None: - raise NotImplementedError - return [None] * len(self.config.depths) - - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFSwinModelOutput | tuple[tf.Tensor, ...]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask) - embedding_output, input_dimensions = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, training=training - ) - - encoder_outputs = self.encoder( - embedding_output, - input_dimensions, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output, training=training) - - pooled_output = None - if self.pooler is not None: - batch_size, _, num_features = shape_list(sequence_output) - pooled_output = self.pooler(sequence_output) - pooled_output = tf.reshape(pooled_output, (batch_size, num_features)) - - if not return_dict: - output = (sequence_output, pooled_output) + encoder_outputs[1:] - return output - - return TFSwinModelOutput( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.num_features]) - - -@add_start_docstrings( - "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", - SWIN_START_DOCSTRING, -) -class TFSwinModel(TFSwinPreTrainedModel): - def __init__( - self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs - ) -> None: - super().__init__(config, **kwargs) - self.config = config - self.swin = TFSwinMainLayer(config, name="swin") - - @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSwinModelOutput, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFSwinModelOutput | tuple[tf.Tensor, ...]: - r""" - bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - swin_outputs = self.swin( - pixel_values=pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return swin_outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "swin", None) is not None: - with tf.name_scope(self.swin.name): - self.swin.build(None) - - -class TFSwinPixelShuffle(keras.layers.Layer): - """TF layer implementation of torch.nn.PixelShuffle""" - - def __init__(self, upscale_factor: int, **kwargs) -> None: - super().__init__(**kwargs) - if not isinstance(upscale_factor, int) or upscale_factor < 2: - raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}") - self.upscale_factor = upscale_factor - - def call(self, x: tf.Tensor) -> tf.Tensor: - hidden_states = x - batch_size, _, _, num_input_channels = shape_list(hidden_states) - block_size_squared = self.upscale_factor**2 - output_depth = int(num_input_channels / block_size_squared) - # When the number of output channels >= 2, PyTorch's PixelShuffle and - # TF's depth_to_space differ in their output as the order of channels selected for combining - # is a permutation of the other c.f. - # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1 - permutation = tf.constant( - [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]] - ) - hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1) - hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC") - return hidden_states - - -class TFSwinDecoder(keras.layers.Layer): - def __init__(self, config: SwinConfig, **kwargs): - super().__init__(**kwargs) - self.conv2d = keras.layers.Conv2D( - filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0" - ) - self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1") - self.config = config - - def call(self, x: tf.Tensor) -> tf.Tensor: - hidden_states = x - # B,C,H,W -> B,H,W,C - hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1)) - hidden_states = self.conv2d(hidden_states) - hidden_states = self.pixel_shuffle(hidden_states) - # B,H,W,C -> B,C,H,W - hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2)) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv2d", None) is not None: - with tf.name_scope(self.conv2d.name): - self.conv2d.build([None, None, None, self.config.hidden_size]) - if getattr(self, "pixel_shuffle", None) is not None: - with tf.name_scope(self.pixel_shuffle.name): - self.pixel_shuffle.build(None) - - -@add_start_docstrings( - "Swin Model with a decoder on top for masked image modeling, as proposed in" - " [SimMIM](https://huggingface.co/papers/2111.09886).", - SWIN_START_DOCSTRING, -) -class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): - def __init__(self, config: SwinConfig): - super().__init__(config) - - self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin") - - self.decoder = TFSwinDecoder(config, name="decoder") - - @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - bool_masked_pos: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple | TFSwinMaskedImageModelingOutput: - r""" - bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - - Returns: - - Examples: - ```python - >>> from transformers import AutoImageProcessor, TFSwinForMaskedImageModeling - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224") - >>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224") - - >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 - >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values - >>> # create random boolean mask of shape (batch_size, num_patches) - >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5 - - >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction - >>> list(reconstructed_pixel_values.shape) - [1, 3, 224, 224] - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.swin( - pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - # Reshape to (batch_size, num_channels, height, width) - sequence_output = tf.transpose(sequence_output, (0, 2, 1)) - batch_size, num_channels, sequence_length = shape_list(sequence_output) - height = width = int(sequence_length**0.5) - sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width)) - - # Reconstruct pixel values - reconstructed_pixel_values = self.decoder(sequence_output) - - masked_im_loss = None - if bool_masked_pos is not None: - size = self.config.image_size // self.config.patch_size - bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size)) - mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1) - mask = tf.repeat(mask, self.config.patch_size, 2) - mask = tf.expand_dims(mask, 1) - mask = tf.cast(mask, tf.float32) - - reconstruction_loss = keras.losses.mean_absolute_error( - # Swap axes as metric calculation reduces over the final dimension - tf.transpose(pixel_values, (1, 2, 3, 0)), - tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)), - ) - reconstruction_loss = tf.expand_dims(reconstruction_loss, 0) - total_loss = tf.reduce_sum(reconstruction_loss * mask) - num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels - masked_im_loss = total_loss / num_masked_pixels - masked_im_loss = tf.reshape(masked_im_loss, (1,)) - - if not return_dict: - output = (reconstructed_pixel_values,) + outputs[2:] - return ((masked_im_loss,) + output) if masked_im_loss is not None else output - - return TFSwinMaskedImageModelingOutput( - loss=masked_im_loss, - reconstruction=reconstructed_pixel_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - reshaped_hidden_states=outputs.reshaped_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "swin", None) is not None: - with tf.name_scope(self.swin.name): - self.swin.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - """ - Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - """, - SWIN_START_DOCSTRING, -) -class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: SwinConfig): - super().__init__(config) - - self.num_labels = config.num_labels - self.swin = TFSwinMainLayer(config, name="swin") - - # Classifier head - self.classifier = ( - keras.layers.Dense(config.num_labels, name="classifier") - if config.num_labels > 0 - else keras.layers.Activation("linear", name="classifier") - ) - - @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFSwinImageClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - @unpack_inputs - def call( - self, - pixel_values: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - labels: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor, ...] | TFSwinImageClassifierOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.swin( - pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - pooled_output = outputs[1] - - logits = self.classifier(pooled_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSwinImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - reshaped_hidden_states=outputs.reshaped_hidden_states, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "swin", None) is not None: - with tf.name_scope(self.swin.name): - self.swin.build(None) - if getattr(self, "classifier", None) is not None: - if hasattr(self.classifier, "name"): - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.swin.num_features]) - - -__all__ = ["TFSwinForImageClassification", "TFSwinForMaskedImageModeling", "TFSwinModel", "TFSwinPreTrainedModel"] diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py deleted file mode 100644 index 71d304ea96c6..000000000000 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ /dev/null @@ -1,203 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert SwitchTransformersX checkpoints from the original repository to JAX/FLAX model.""" - -import argparse -import re - -from flax.traverse_util import flatten_dict, unflatten_dict -from t5x import checkpoints - -from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration -from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model -from transformers.utils import logging - - -logging.set_verbosity_info() - - -# should not include what is already done by the `from_pt` argument -MOE_LAYER_NAME_MAPPING = { - "/attention/": "/0/SelfAttention/", - "/self_attention/": "/0/SelfAttention/", - "/encoder_decoder_attention/": "/1/EncDecAttention/", - "value": "v", - "query": "q", - "key": "k", - "out": "o", - "pre_self_attention_layer_norm": "0/layer_norm", - "pre_cross_attention_layer_norm": "1/layer_norm", - "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong - "token_embedder": "shared", - "encoder_norm": "final_layer_norm", - "decoder_norm": "final_layer_norm", - "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", - "router/router_weights/w/": "router/classifier/", - "roer/roer_weights/w/": "router/classifier/", - "logits_dense": "lm_head", -} - - -def rename_keys(s_dict): - # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in - # the original model - keys = list(s_dict.keys()) - for key in keys: - layer_to_block_of_layer = r".*/layers_(\d+)" - new_key = key - if re.match(layer_to_block_of_layer, key): - new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) - - layer_to_block_of_layer = r"(encoder|decoder)\/" - - if re.match(layer_to_block_of_layer, key): - groups = re.match(layer_to_block_of_layer, new_key).groups() - if groups[0] == "encoder": - new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key) - new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) - - elif groups[0] == "decoder": - new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key) - new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key) - - # 2. Convert other classic mappings - for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): - if old_key in new_key: - new_key = new_key.replace(old_key, temp_key) - - print(f"{key} -> {new_key}") - s_dict[new_key] = s_dict.pop(key) - - if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: - s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ - "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" - ].T - if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: - s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ - "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" - ].T - - # 3. Take extra care of the EXPERTS layer - for key in list(s_dict.keys()): - if "expert" in key: - num_experts = s_dict[key].shape[0] - expert_weihts = s_dict[key] - for idx in range(num_experts): - s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] - print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}") - - s_dict.pop(key) - - return s_dict - - -GIN_TO_CONFIG_MAPPING = { - "NUM_ENCODER_LAYERS": "num_layers", - "NUM_DECODER_LAYERS": "num_decoder_layers", - "NUM_HEADS": "num_heads", - "HEAD_DIM": "d_kv", - "EMBED_DIM": "d_model", - "MLP_DIM": "d_ff", - "NUM_SELECTED_EXPERTS": "num_selected_experts", - "NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers", - "NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers", - "dense.MlpBlock.activations": "feed_forward_proj", -} - - -def convert_gin_to_config(gin_file, num_experts): - # Convert a google style config to the hugging face format - import regex as re - - with open(gin_file, "r") as f: - raw_gin = f.read() - - regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin) - args = {} - for param, value in regex_match: - if param in GIN_TO_CONFIG_MAPPING and value != "": - args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value) - - activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0] - args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1]) - - args["num_experts"] = num_experts - config = SwitchTransformersConfig(**args) - return config - - -def convert_flax_checkpoint_to_pytorch( - flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8 -): - # Initialise PyTorch model - - print(f"Loading flax weights from : {flax_checkpoint_path}") - flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path) - - if gin_file is not None: - config = convert_gin_to_config(gin_file, num_experts) - else: - config = SwitchTransformersConfig.from_pretrained(config_file) - - pt_model = SwitchTransformersForConditionalGeneration(config) - - flax_params = flax_params["target"] - flax_params = flatten_dict(flax_params, sep="/") - flax_params = rename_keys(flax_params) - flax_params = unflatten_dict(flax_params, sep="/") - - # Load the flax params in the PT model - load_flax_weights_in_pytorch_model(pt_model, flax_params) - - print(f"Save PyTorch model to {pytorch_dump_path}") - pt_model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--switch_t5x_checkpoint_path", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the" - " model architecture. If not provided, a `gin_file` has to be provided." - ), - ) - parser.add_argument( - "--gin_file", - default=None, - type=str, - required=False, - help="Path to the gin config file. If not provided, a `config_file` has to be passed ", - ) - parser.add_argument( - "--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model." - ) - parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." - ) - parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts") - args = parser.parse_args() - convert_flax_checkpoint_to_pytorch( - args.switch_t5x_checkpoint_path, - args.config_name, - args.gin_file, - args.pytorch_dump_folder_path, - args.num_experts, - ) diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 9b1b15857cea..000000000000 --- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The T5 authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert T5 checkpoint.""" - -import argparse - -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - config = T5Config.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = T5ForConditionalGeneration(config) - - # Load weights from tf checkpoint - load_tf_weights_in_t5(model, config, tf_checkpoint_path) - - # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py deleted file mode 100644 index c829a084e4d9..000000000000 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ /dev/null @@ -1,1801 +0,0 @@ -# coding=utf-8 -# Copyright 2021 T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax T5 model.""" - -import copy -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_t5 import T5Config - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google-t5/t5-small" -_CONFIG_FOR_DOC = "T5Config" - -remat = nn_partitioning.remat - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = jnp.zeros_like(input_ids) - shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) - shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) - - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -class FlaxT5LayerNorm(nn.Module): - hidden_size: int - dtype: jnp.dtype = jnp.float32 - eps: float = 1e-6 - weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones - - def setup(self): - self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) - - def __call__(self, hidden_states): - """ - Construct a layernorm module in the T5 style; No bias and no subtraction of mean. - """ - # layer norm should always be calculated in float32 - variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) - hidden_states = hidden_states / jnp.sqrt(variance + self.eps) - - return self.weight * hidden_states - - -class FlaxT5DenseActDense(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -class FlaxT5DenseGatedActDense(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi_0 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wi_1 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -class FlaxT5LayerFF(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.is_gated_act: - self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) - else: - self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) - - self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__(self, hidden_states, deterministic=True): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) - hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) - return hidden_states - - -class FlaxT5Attention(nn.Module): - config: T5Config - has_relative_attention_bias: bool = False - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.relative_attention_num_buckets = self.config.relative_attention_num_buckets - self.relative_attention_max_distance = self.config.relative_attention_max_distance - self.d_model = self.config.d_model - self.key_value_proj_dim = self.config.d_kv - self.n_heads = self.config.num_heads - self.dropout = self.config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - - self.q = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std), - dtype=self.dtype, - ) - self.k = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.v = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.o = nn.Dense( - self.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std), - dtype=self.dtype, - ) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0) * num_buckets - relative_position = jnp.abs(relative_position) - else: - relative_position = -jnp.clip(relative_position, a_max=0) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) - ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) - - relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) - - return relative_buckets.astype("i4") - - def compute_bias(self, query_length, key_length): - """Compute binned relative position bias""" - context_position = jnp.arange(query_length, dtype="i4")[:, None] - memory_position = jnp.arange(key_length, dtype="i4")[None, :] - - relative_position = memory_position - context_position - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=(not self.causal), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - - values = self.relative_attention_bias(relative_position_bucket) - values = values.transpose((2, 0, 1))[None, :, :, :] - return values - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) - value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions - # that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def _create_position_bias( - self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ): - cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) - key_length = key_states.shape[1] - query_length = key_length if cache_is_filled else query_states.shape[1] - - if self.has_relative_attention_bias: - position_bias = self.compute_bias(query_length, key_length) - elif attention_mask is not None: - position_bias = jnp.zeros_like(attention_mask) - else: - position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) - - # if key and values are already calculated, only the last query position bias should be taken - if cache_is_filled: - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - position_bias = jax.lax.dynamic_slice( - position_bias, - (0, 0, causal_attention_mask_shift, 0), - (1, self.n_heads, seq_length, max_decoder_length), - ) - return position_bias - - def __call__( - self, - hidden_states, - attention_mask=None, - key_value_states=None, - position_bias=None, - use_cache=False, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - batch_size, seq_length = hidden_states.shape[:2] - - # q, k, v projections - query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) - key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) - value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) - - # reshape to (batch_size, seq_length, n_heads, head_dim) - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # counter-act scaling in dot_product_attention_weights function - query_states *= jnp.sqrt(query_states.shape[-1]) - - # for fast decoding causal attention mask should be shifted - causal_attention_mask_shift = ( - self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 - ) - # create causal attention_mask; attention_mask has to be defined when model is causal - if self.causal: - causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") - - # fast decoding for generate requires special attention_mask - if self.has_variable("cache", "cached_key"): - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_attention_mask = jax.lax.dynamic_slice( - causal_attention_mask, - (0, 0, causal_attention_mask_shift, 0), - (1, 1, seq_length, max_decoder_length), - ) - - # broadcast causal attention mask & attention mask to fit for merge - causal_attention_mask = jnp.broadcast_to( - causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] - ) - attention_mask = jnp.broadcast_to( - jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape - ) - attention_mask = combine_masks(attention_mask, causal_attention_mask) - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # replace masked positions with -10_000 - if attention_mask is not None: - mask_value = jnp.finfo(self.dtype).min - attention_mask = jax.lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, mask_value).astype(self.dtype), - ) - - if position_bias is None: - # compute position bias (only for first layer) - position_bias = self._create_position_bias( - key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ) - - if attention_mask is not None: - position_bias = position_bias + attention_mask - - # create dropout rng - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # Softmax(QK^T) - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=position_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - ) - - # multiply with value states - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - # bring back to (batch_size, seq_length, d_model) - attn_output = self._merge_heads(attn_output) - - # apply output matrix - attn_output = self.o(attn_output) - - outputs = (attn_output, position_bias) - - if output_attentions: - outputs = outputs + (attn_weights,) - - return outputs - - -class FlaxT5LayerSelfAttention(nn.Module): - config: T5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.SelfAttention = FlaxT5Attention( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - causal=self.config.causal, - dtype=self.dtype, - ) - self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -class FlaxT5LayerCrossAttention(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.EncDecAttention = FlaxT5Attention( - self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype - ) - self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - attention_mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -class FlaxT5Block(nn.Module): - config: T5Config - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.causal = self.config.causal - self.layer = ( - FlaxT5LayerSelfAttention( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - name=str(0), - dtype=self.dtype, - ), - ) - feed_forward_index = 1 - if self.causal: - self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) - feed_forward_index += 1 - - self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - return_dict=True, - deterministic=True, - init_cache=False, - ): - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = self_attention_outputs[0] - attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights - - do_cross_attention = self.causal and encoder_hidden_states is not None - if do_cross_attention: - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = cross_attention_outputs[0] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[1:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - outputs = outputs + attention_outputs - - # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - return outputs - - -class FlaxT5LayerCollection(nn.Module): - config: T5Config - has_relative_attention_bias: bool - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxT5Block( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype - ) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - return self.layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - - -class FlaxT5BlockCollection(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - if self.gradient_checkpointing: - FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) - self.blocks = [ - FlaxT5CheckpointLayer( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - else: - self.blocks = [ - FlaxT5LayerCollection( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - - def __call__( - self, - hidden_states=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - deterministic: bool = True, - init_cache: bool = False, - ): - # Prepare head mask if needed - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.causal) else None - position_bias = None - encoder_decoder_position_bias = None - - for i, layer_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask, - position_bias, - encoder_hidden_states, - encoder_attention_mask, - encoder_decoder_position_bias, - output_attentions, - deterministic, - init_cache, - ) - - hidden_states = layer_outputs[0] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[1] - - if self.causal and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) - if self.causal: - all_cross_attentions = all_cross_attentions + (layer_outputs[4],) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class FlaxT5Stack(nn.Module): - config: T5Config - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - - self.block = FlaxT5BlockCollection( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.final_layer_norm = FlaxT5LayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - init_cache: bool = False, - ): - hidden_states = self.embed_tokens(input_ids) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - deterministic=deterministic, - init_cache=init_cache, - ) - - hidden_states = outputs[0] - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - # Add last layer - all_hidden_states = None - - if output_hidden_states: - all_hidden_states = outputs.hidden_states - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - if output_hidden_states: - return ( - hidden_states, - all_hidden_states, - ) + outputs[2:] - return (hidden_states,) + outputs[1:] - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -T5_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -T5_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For training, `decoder_input_ids` should be provided. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -T5_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 - Training](./t5#training). - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxT5PreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = T5Config - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: T5Config, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - - attention_mask = jnp.ones_like(input_ids) - args = [input_ids, attention_mask] - if self.module_class not in [FlaxT5EncoderModule]: - decoder_input_ids = jnp.ones_like(input_ids) - decoder_attention_mask = jnp.ones_like(input_ids) - args.extend([decoder_input_ids, decoder_attention_mask]) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - *args, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: jnp.ndarray = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if decoder_input_ids is None: - raise ValueError( - "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" - " here." - ) - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # prepare decoder inputs - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxT5Attention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - -T5_START_DOCSTRING = r""" - The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`T5Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -@add_start_docstrings( - "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.", - T5_START_DOCSTRING, -) -class FlaxT5Module(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), - dtype=self.dtype, - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - self.encoder = FlaxT5Stack( - encoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxT5Stack( - decoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode if needed (training, first prediction pass) - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxT5Model(FlaxT5PreTrainedModel): - module_class = FlaxT5Module - - -append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - -FLAX_T5_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxT5Model - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="np" - ... ).input_ids - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. - >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - - -overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - - -@add_start_docstrings( - "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", - T5_START_DOCSTRING, -) -class FlaxT5EncoderModule(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), - dtype=self.dtype, - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.is_decoder = False - encoder_config.is_encoder_decoder = False - encoder_config.causal = False - self.encoder = FlaxT5Stack( - encoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict: bool = True, - deterministic: bool = True, - ): - # Encode if needed (training, first prediction pass) - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - return encoder_outputs - - -class FlaxT5EncoderModel(FlaxT5PreTrainedModel): - module_class = FlaxT5EncoderModule - - @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class FlaxT5ForConditionalGenerationModule(nn.Module): - config: T5Config - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.model_dim = self.config.d_model - - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), - dtype=self.dtype, - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = FlaxT5Stack( - encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxT5Stack( - decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), - dtype=self.dtype, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = self.shared.variables["params"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = self.lm_head(sequence_output) - - if not return_dict: - return (lm_logits,) + decoder_outputs[1:] + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): - module_class = FlaxT5ForConditionalGenerationModule - - @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - - >>> text = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxT5Attention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - decoder_outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.config.d_model**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = module.shared.variables["params"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = module.lm_head(sequence_output) - - return lm_logits, decoder_outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - extended_attention_mask = jax.lax.dynamic_update_slice( - extended_attention_mask, decoder_attention_mask, (0, 0) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - return model_kwargs - - -FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - - >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"]).sequences - >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` -""" - - -overwrite_call_docstring( - FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py deleted file mode 100644 index 142a0f73115e..000000000000 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ /dev/null @@ -1,1676 +0,0 @@ -# coding=utf-8 -# Copyright 2020 T5 Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 T5 model.""" - -from __future__ import annotations - -import copy -import itertools -import math -import warnings - -import numpy as np -import tensorflow as tf -from tensorflow.compiler.tf2xla.python.xla import dynamic_slice - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_t5 import T5Config - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "T5Config" - - -#################################################### -# TF 2.0 Models are constructed using Keras imperative API by sub-classing -# - keras.layers.Layer for the layers and -# - TFPreTrainedModel for the models (it-self a sub-class of keras.Model) -#################################################### - - -class TFT5LayerNorm(keras.layers.Layer): - def __init__(self, hidden_size, epsilon=1e-6, **kwargs): - """ - Construct a layernorm module in the T5 style No bias and no subtraction of mean. - """ - super().__init__(**kwargs) - self.variance_epsilon = epsilon - self.hidden_size = hidden_size - - def build(self, input_shape): - """Build shared word embedding layer""" - self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones") - super().build(input_shape) - - def call(self, hidden_states): - variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) - hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -class TFT5DenseActDense(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - wi_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) - ) - wo_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) - ) - self.wi = keras.layers.Dense( - config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wo = keras.layers.Dense( - config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer - ) # Update init weights as in flax - self.dropout = keras.layers.Dropout(config.dropout_rate) - self.act = get_tf_activation(config.dense_act_fn) - self.config = config - - def call(self, hidden_states, training=False): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.wo(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wi", None) is not None: - with tf.name_scope(self.wi.name): - self.wi.build([None, None, self.config.d_model]) - if getattr(self, "wo", None) is not None: - with tf.name_scope(self.wo.name): - self.wo.build([None, None, self.config.d_ff]) - - -class TFT5DenseGatedActDense(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - wi_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) - ) - wo_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) - ) - self.wi_0 = keras.layers.Dense( - config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wi_1 = keras.layers.Dense( - config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wo = keras.layers.Dense( - config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer - ) # Update init weights as in flax - self.dropout = keras.layers.Dropout(config.dropout_rate) - self.act = get_tf_activation(config.dense_act_fn) - self.config = config - - def call(self, hidden_states, training=False): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.wo(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wi_0", None) is not None: - with tf.name_scope(self.wi_0.name): - self.wi_0.build([None, None, self.config.d_model]) - if getattr(self, "wi_1", None) is not None: - with tf.name_scope(self.wi_1.name): - self.wi_1.build([None, None, self.config.d_model]) - if getattr(self, "wo", None) is not None: - with tf.name_scope(self.wo.name): - self.wo.build([None, None, self.config.d_ff]) - - -class TFT5LayerFF(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - if config.is_gated_act: - self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense") - else: - self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense") - - self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = keras.layers.Dropout(config.dropout_rate) - - def call(self, hidden_states, training=False): - normed_hidden_states = self.layer_norm(hidden_states) - dense_output = self.DenseReluDense(normed_hidden_states, training=training) - hidden_states = hidden_states + self.dropout(dense_output, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build(None) - if getattr(self, "DenseReluDense", None) is not None: - with tf.name_scope(self.DenseReluDense.name): - self.DenseReluDense.build(None) - - -class TFT5Attention(keras.layers.Layer): - NEW_ID = itertools.count() - - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.layer_id = next(TFT5Attention.NEW_ID) - self.is_decoder = config.is_decoder - self.use_cache = config.use_cache - self.has_relative_attention_bias = has_relative_attention_bias - self.output_attentions = config.output_attentions - - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.inner_dim = self.n_heads * self.key_value_proj_dim - - # Mesh TensorFlow initialization to avoid scaling before softmax - q_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - ) - k_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - v_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - o_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - self.relative_attention_bias_initializer = keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - - self.q = keras.layers.Dense( - self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer - ) # Update init weights as in flax - self.k = keras.layers.Dense( - self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer - ) # Update init weights as in flax - self.v = keras.layers.Dense( - self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer - ) # Update init weights as in flax - self.o = keras.layers.Dense( - self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer - ) # Update init weights as in flax - self.dropout = keras.layers.Dropout(config.dropout_rate) - - self.pruned_heads = set() - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.has_relative_attention_bias: - with tf.name_scope("relative_attention_bias"): - self.relative_attention_bias = self.add_weight( - name="embeddings", - shape=[self.relative_attention_num_buckets, self.n_heads], - initializer=self.relative_attention_bias_initializer, # Add initializer - ) - if getattr(self, "q", None) is not None: - with tf.name_scope(self.q.name): - self.q.build([None, None, self.d_model]) - if getattr(self, "k", None) is not None: - with tf.name_scope(self.k.name): - self.k.build([None, None, self.d_model]) - if getattr(self, "v", None) is not None: - with tf.name_scope(self.v.name): - self.v.build([None, None, self.d_model]) - if getattr(self, "o", None) is not None: - with tf.name_scope(self.o.name): - self.o.build([None, None, self.inner_dim]) - - def prune_heads(self, heads): - raise NotImplementedError - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - # n = -relative_position - if bidirectional: - num_buckets //= 2 - relative_buckets += ( - tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets - ) - relative_position = tf.math.abs(relative_position) - else: - relative_position = -tf.math.minimum(relative_position, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = tf.math.less(relative_position, max_exact) - relative_position_if_large = max_exact + tf.cast( - tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32)) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact), - dtype=relative_position.dtype, - ) - relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) - relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length, key_length): - """Compute binned relative position bias""" - context_position = tf.range(query_length)[:, None] - memory_position = tf.range(key_length)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = tf.gather( - self.relative_attention_bias, relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = tf.expand_dims( - tf.transpose(values, [2, 0, 1]), axis=0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def call( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - training=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, query_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - batch_size, seq_length = shape_list(hidden_states)[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert len(past_key_value) == 2, ( - f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - ) - real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] - - def shape(hidden_states): - """projection""" - return tf.transpose( - tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) - ) - - def unshape(hidden_states): - """compute context""" - return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = tf.concat([past_key_value, hidden_states], axis=2) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head) - - # get key/value - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) - - # to cope with keras serialization - if self.is_decoder and use_cache: - present_key_value_state = (key_states, value_states) - else: - present_key_value_state = None - - scores = tf.einsum( - "bnqd,bnkd->bnqk", query_states, key_states - ) # (batch_size, n_heads, query_length, key_length) - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) - else: - position_bias = self.compute_bias(real_seq_length, key_length) - - # if key and values are already calculated we want only the last query position bias - if past_key_value is not None: - if not self.has_relative_attention_bias: - position_bias = position_bias[:, :, -seq_length:, :] - else: - # we might have a padded past structure, in which case we want to fetch the position bias slice - # right after the most recently filled past index - most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) - position_bias = dynamic_slice( - position_bias, - (0, 0, most_recently_filled_past_index + 1, 0), - (1, self.n_heads, seq_length, real_seq_length), - ) - - if mask is not None: - position_bias = tf.cast(position_bias, dtype=mask.dtype) - position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) - - scores += position_bias - weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) - weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.n_heads], - message=( - f"Head mask for a single layer should be of size {(self.n_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights - - attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) - - attn_output = self.o(unshape(attn_output)) - - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (weights,) - - return outputs - - -class TFT5LayerSelfAttention(keras.layers.Layer): - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.SelfAttention = TFT5Attention( - config, - has_relative_attention_bias=has_relative_attention_bias, - name="SelfAttention", - ) - self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = keras.layers.Dropout(config.dropout_rate) - - def call( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - training=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], training=training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "SelfAttention", None) is not None: - with tf.name_scope(self.SelfAttention.name): - self.SelfAttention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build(None) - - -class TFT5LayerCrossAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.EncDecAttention = TFT5Attention( - config, - has_relative_attention_bias=False, - name="EncDecAttention", - ) - self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = keras.layers.Dropout(config.dropout_rate) - - def call( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - query_length=None, - use_cache=False, - output_attentions=False, - training=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], training=training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "EncDecAttention", None) is not None: - with tf.name_scope(self.EncDecAttention.name): - self.EncDecAttention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build(None) - - -class TFT5Block(keras.layers.Layer): - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.is_decoder = config.is_decoder - self.layer = [] - self.layer.append( - TFT5LayerSelfAttention( - config, - has_relative_attention_bias=has_relative_attention_bias, - name="layer_._0", - ) - ) - if self.is_decoder: - self.layer.append( - TFT5LayerCrossAttention( - config, - name="layer_._1", - ) - ) - - self.layer.append(TFT5LayerFF(config, name=f"layer_._{len(self.layer)}")) - - def call( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - encoder_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - training=False, - ): - if past_key_value is not None: - assert self.is_decoder, "Only decoder can use `past_key_values`" - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention' if expected_num_past_key_values == 4 else ''}. " - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights - - if self.is_decoder and encoder_hidden_states is not None: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = shape_list(present_key_value_state[0])[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=encoder_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = cross_attention_outputs[0] - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states, training=training) - outputs = (hidden_states,) - - # Add attentions if we output them - outputs = outputs + (present_key_value_state,) + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - for layer_module in self.layer: - if hasattr(layer_module, "name"): - with tf.name_scope(layer_module.name): - layer_module.build(None) - - -#################################################### -# The full model without a specific pretrained or finetuning head is -# provided as a keras.layers.Layer usually called "TFT5MainLayer" -#################################################### -@keras_serializable -class TFT5MainLayer(keras.layers.Layer): - config_class = T5Config - - def __init__(self, config, embed_tokens=None, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.use_cache = config.use_cache - - self.embed_tokens = embed_tokens - self.is_decoder = config.is_decoder - - self.config = config - self.num_hidden_layers = config.num_layers - - self.block = [ - TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}") - for i in range(config.num_layers) - ] - self.final_layer_norm = TFT5LayerNorm( - config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm" - ) - self.dropout = keras.layers.Dropout(config.dropout_rate) - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - encoder_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ) -> tuple: - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") - - if inputs_embeds is None: - assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length - ) - - if attention_mask is None: - attention_mask = tf.fill((batch_size, mask_seq_length), 1) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = shape_list(encoder_hidden_states)[1] - encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype) - num_dims_attention_mask = len(shape_list(attention_mask)) - if num_dims_attention_mask == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif num_dims_attention_mask == 2: - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - if past_key_values[0] is not None: - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -1e9 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # extended_attention_mask = tf.math.equal(extended_attention_mask, - # tf.transpose(extended_attention_mask, perm=(-1, -2))) - - extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 - - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 - else: - encoder_extended_attention_mask = None - - present_key_value_states = () if use_cache and self.is_decoder else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds, training=training) - - for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, past_key_values, (self-attention weights), - # (self-attention position bias), (cross-attention position bias), (cross-attention weights), - position_bias = layer_outputs[2] - - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - - # append next layer key value states - if present_key_value_state is not None and use_cache and self.is_decoder: - present_key_value_states = present_key_value_states + (present_key_value_state,) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - outputs = (hidden_states,) - # need to check if is decoder here as well for special cases when using keras compile - if use_cache and self.is_decoder: - outputs = outputs + (present_key_value_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_attentions,) - if self.is_decoder: - outputs + (all_cross_attentions,) - return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions) - - if self.is_decoder: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - else: - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build(None) - if getattr(self, "block", None) is not None: - for layer in self.block: - with tf.name_scope(layer.name): - layer.build(None) - - -#################################################### -# TFT5PreTrainedModel is a sub-class of keras.Model -# which take care of loading and saving pretrained weights -# and various common utilities. -# Here you just need to specify a few (self-explanatory) -# pointers for your model. -#################################################### -class TFT5PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = T5Config - base_model_prefix = "transformer" - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"] - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - self.encoder.embed_tokens = self.shared - if hasattr(self, "decoder"): - self.decoder.embed_tokens = self.shared - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the" - " pad_token_id. See T5 docs for more information" - ) - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) - start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal( - shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype) - ) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -T5_START_DOCSTRING = r""" - - The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`T5Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -T5_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on the right or the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `inputs` for pretraining take a look at [T5 Training](./t5#training). - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Provide for sequence to sequence training. T5 uses the `pad_token_id` as the starting token for - `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` - have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 - Training](./t5#training). - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(tf.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -T5_ENCODER_INPUTS_DOCSTRING = r""" - Args: - inputs (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on the right or the left. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - To know more on how to prepare `inputs` for pre-training take a look at [T5 Training](./t5#training). - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -_HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, -num_heads))`. -""" - - -@add_start_docstrings( - "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.", - T5_START_DOCSTRING, -) -class TFT5Model(TFT5PreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.shared = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(self.config.initializer_factor), - name="shared", - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "shared" - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder") - - def get_encoder(self): - return self.encoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFSeq2SeqModelOutput: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFT5Model - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = TFT5Model.from_pretrained("google-t5/t5-small") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. - >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - past_key_values=None, - use_cache=False, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - inputs_embeds=decoder_inputs_embeds, - head_mask=decoder_head_mask, - encoder_head_mask=head_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - past = decoder_outputs[1] if use_cache else None - - if not return_dict: - if past_key_values is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=past, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model_dim = config.d_model - self.shared = keras.layers.Embedding( - config.vocab_size, - config.d_model, - name="shared", - embeddings_initializer=get_initializer(self.config.initializer_factor), - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "shared" - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder") - - if not config.tie_word_embeddings: - lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor) - self.lm_head = keras.layers.Dense( - config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer - ) # Update init weights as in flax - self.config = config - - def get_output_embeddings(self): - if self.config.tie_word_embeddings: - return self.get_input_embeddings() - else: - # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) - # value has a shape (num_tokens, dim) then needs to be transposed - return tf.transpose(self.lm_head.kernel) - - def set_output_embeddings(self, value): - if self.config.tie_word_embeddings: - self.set_input_embeddings(value) - else: - lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor) - self.lm_head = keras.layers.Dense( - shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer - ) # Update init weights as in flax - # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) - # value has a shape (num_tokens, dim) then needs to be transposed - transposed_value = tf.transpose(value) - self.lm_head.kernel = transposed_value - - def get_encoder(self): - return self.encoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFSeq2SeqLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFT5ForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - - >>> # training - >>> inputs = tokenizer("The walks in park", return_tensors="tf").input_ids - >>> labels = tokenizer(" cute dog the ", return_tensors="tf").input_ids - >>> outputs = model(inputs, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - - >>> # inference - >>> inputs = tokenizer( - ... "summarize: studies have shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> outputs = model.generate(inputs) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you - ```""" - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - hidden_states = encoder_outputs[0] - - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - inputs_embeds=decoder_inputs_embeds, - head_mask=decoder_head_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = decoder_outputs[0] - - # T5v1.1 does not tie output word embeddings and thus does not require downscaling - if self.config.tie_word_embeddings: - sequence_output = sequence_output * (self.model_dim**-0.5) - logits = tf.matmul(sequence_output, self.shared.weights, transpose_b=True) - else: - logits = self.lm_head(sequence_output) - - logits = tf.cast(logits, tf.float32) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - past = decoder_outputs[1] if use_cache else None - if not return_dict: - if past_key_values is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] - output = (logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif isinstance(encoder_outputs, tuple): - last_hidden_state = encoder_outputs[0] - hidden_states = None - attentions = None - idx = 0 - if output_hidden_states: - idx += 1 - hidden_states = encoder_outputs[idx] - if output_attentions: - idx += 1 - attentions = encoder_outputs[idx] - - encoder_outputs = TFBaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) - - return TFSeq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=past, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return self._shift_right(labels) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.config.d_model]) - - -@add_start_docstrings( - "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.", - T5_START_DOCSTRING, -) -class TFT5EncoderModel(TFT5PreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.shared = keras.layers.Embedding( - config.vocab_size, - config.d_model, - name="shared", - embeddings_initializer=get_initializer(self.config.initializer_factor), - ) - # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) - self.shared.load_weight_prefix = "shared" - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") - - def get_encoder(self): - return self.encoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutput: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TFT5EncoderModel - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") - >>> model = TFT5EncoderModel.from_pretrained("google-t5/t5-small") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> outputs = model(input_ids) - ```""" - - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - past_key_values=None, - use_cache=False, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return encoder_outputs - - return TFBaseModelOutput( - last_hidden_state=encoder_outputs.last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # The shared/tied weights expect to be in the model base namespace - # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than - # the current one. - with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): - self.shared.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - -__all__ = ["TFT5EncoderModel", "TFT5ForConditionalGeneration", "TFT5Model", "TFT5PreTrainedModel"] diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py deleted file mode 100644 index 34bf77cccd6b..000000000000 --- a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,137 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert TAPAS checkpoint.""" - -import argparse - -from transformers import ( - TapasConfig, - TapasForMaskedLM, - TapasForQuestionAnswering, - TapasForSequenceClassification, - TapasModel, - TapasTokenizer, - load_tf_weights_in_tapas, -) -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_tf_checkpoint_to_pytorch( - task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path -): - # Initialise PyTorch model. - # If you want to convert a checkpoint that uses absolute position embeddings, make sure to set reset_position_index_per_cell of - # TapasConfig to False. - - # initialize configuration from json file - config = TapasConfig.from_json_file(tapas_config_file) - # set absolute/relative position embeddings parameter - config.reset_position_index_per_cell = reset_position_index_per_cell - - # set remaining parameters of TapasConfig as well as the model based on the task - if task == "SQA": - model = TapasForQuestionAnswering(config=config) - elif task == "WTQ": - # run_task_main.py hparams - config.num_aggregation_labels = 4 - config.use_answer_as_supervision = True - # hparam_utils.py hparams - config.answer_loss_cutoff = 0.664694 - config.cell_selection_preference = 0.207951 - config.huber_loss_delta = 0.121194 - config.init_cell_selection_weights_to_zero = True - config.select_one_column = True - config.allow_empty_column_selection = False - config.temperature = 0.0352513 - - model = TapasForQuestionAnswering(config=config) - elif task == "WIKISQL_SUPERVISED": - # run_task_main.py hparams - config.num_aggregation_labels = 4 - config.use_answer_as_supervision = False - # hparam_utils.py hparams - config.answer_loss_cutoff = 36.4519 - config.cell_selection_preference = 0.903421 - config.huber_loss_delta = 222.088 - config.init_cell_selection_weights_to_zero = True - config.select_one_column = True - config.allow_empty_column_selection = True - config.temperature = 0.763141 - - model = TapasForQuestionAnswering(config=config) - elif task == "TABFACT": - model = TapasForSequenceClassification(config=config) - elif task == "MLM": - model = TapasForMaskedLM(config=config) - elif task == "INTERMEDIATE_PRETRAINING": - model = TapasModel(config=config) - else: - raise ValueError(f"Task {task} not supported.") - - print(f"Building PyTorch model from configuration: {config}") - # Load weights from tf checkpoint - load_tf_weights_in_tapas(model, config, tf_checkpoint_path) - - # Save pytorch-model (weights and configuration) - print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) - - # Save tokenizer files - print(f"Save tokenizer files to {pytorch_dump_path}") - tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512) - tokenizer.save_pretrained(pytorch_dump_path) - - print("Used relative position embeddings:", model.config.reset_position_index_per_cell) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--task", default="SQA", type=str, help="Model task for which to convert a checkpoint. Defaults to SQA." - ) - parser.add_argument( - "--reset_position_index_per_cell", - default=False, - action="store_true", - help="Whether to use relative position embeddings or not. Defaults to True.", - ) - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--tapas_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained TAPAS model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_tf_checkpoint_to_pytorch( - args.task, - args.reset_position_index_per_cell, - args.tf_checkpoint_path, - args.tapas_config_file, - args.pytorch_dump_path, - ) diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py deleted file mode 100644 index 624df1fba176..000000000000 --- a/src/transformers/models/tapas/modeling_tf_tapas.py +++ /dev/null @@ -1,2461 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Google Research and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 TAPAS model.""" - -from __future__ import annotations - -import enum -import math -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPooling, - TFMaskedLMOutput, - TFSequenceClassifierOutput, -) -from ...modeling_tf_utils import ( - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_tensorflow_probability_available, - logging, - replace_return_docstrings, -) -from .configuration_tapas import TapasConfig - - -logger = logging.get_logger(__name__) - -# soft dependency -if is_tensorflow_probability_available(): - try: - import tensorflow_probability as tfp - - # On the first call, check whether a compatible version of TensorFlow is installed - # TensorFlow Probability depends on a recent stable release of TensorFlow - n = tfp.distributions.Normal(loc=0.0, scale=1.0) - except ImportError: - logger.error( - "TAPAS models are not usable since `tensorflow_probability` can't be loaded. " - "It seems you have `tensorflow_probability` installed with the wrong tensorflow version. " - "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability." - ) -else: - try: - import tensorflow_probability as tfp - - # On the first call, check whether a compatible version of TensorFlow is installed - # TensorFlow Probability depends on a recent stable release of TensorFlow - _ = tfp.distributions.Normal(loc=0.0, scale=1.0) - except ImportError: - pass - -_CONFIG_FOR_DOC = "TapasConfig" -_CHECKPOINT_FOR_DOC = "google/tapas-base" - - -EPSILON_ZERO_DIVISION = 1e-10 -CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0 - - -@dataclass -class TFTableQuestionAnsweringOutput(ModelOutput): - """ - Output type of [`TFTapasForQuestionAnswering`]. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)): - Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the - semi-supervised regression loss and (optionally) supervised loss for aggregations. - logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Prediction scores of the cell selection head, for every token. - logits_aggregation (`tf.Tensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`): - Prediction scores of the aggregation head, for every aggregation operator. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - logits_aggregation: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -class TFTapasEmbeddings(keras.layers.Layer): - """ - Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of - additional token type embeddings to encode tabular structure. - """ - - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.number_of_token_type_embeddings = len(config.type_vocab_sizes) - self.reset_position_index_per_cell = config.reset_position_index_per_cell - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - for i, type_vocab_size in enumerate(self.config.type_vocab_sizes): - with tf.name_scope(f"token_type_embeddings_{i}"): - setattr( - self, - f"token_type_embeddings_{i}", - self.add_weight( - name="embeddings", - shape=[type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def call( - self, - input_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - training: bool = False, - ) -> tf.Tensor: - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - if input_ids is not None: - input_shape = shape_list(input_ids) - else: - input_shape = shape_list(inputs_embeds)[:-1] - - seq_length = input_shape[1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape + [self.number_of_token_type_embeddings], value=0) - - if position_ids is None: - # create absolute position embeddings - position_ids = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0) - position_ids = tf.broadcast_to(position_ids, shape=input_shape) - # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings - if self.reset_position_index_per_cell: - # shape (batch_size, seq_len) - col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1) - # shape (batch_size, seq_len) - row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1) - # shape (batch_size, seq_len) - full_index = ProductIndexMap(col_index, row_index) - # shape (max_rows * max_columns,). First absolute position for every cell - first_position_per_segment = reduce_min(position_ids, full_index)[0] - # ? shape (batch_size, seq_len). First absolute position of the cell for every token - first_position = gather(first_position_per_segment, full_index) - # shape (1, seq_len) - position = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0) - position_ids = tf.math.minimum(self.max_position_embeddings - 1, position - first_position) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - position_embeddings = tf.gather(self.position_embeddings, indices=position_ids) - - final_embeddings = inputs_embeds + position_embeddings - - for i in range(self.number_of_token_type_embeddings): - name = f"token_type_embeddings_{i}" - final_embeddings += tf.gather(params=getattr(self, name), indices=token_type_ids[:, :, i]) - - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Tapas -class TFTapasSelfAttention(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFTapasModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Tapas -class TFTapasSelfOutput(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Tapas -class TFTapasAttention(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFTapasSelfAttention(config, name="self") - self.dense_output = TFTapasSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Tapas -class TFTapasIntermediate(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Tapas -class TFTapasOutput(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Tapas -class TFTapasLayer(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFTapasAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFTapasAttention(config, name="crossattention") - self.intermediate = TFTapasIntermediate(config, name="intermediate") - self.bert_output = TFTapasOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Tapas -class TFTapasEncoder(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFTapasLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Tapas -class TFTapasPooler(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Tapas -class TFTapasPredictionHeadTransform(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - name="dense", - ) - - if isinstance(config.hidden_act, str): - self.transform_act_fn = get_tf_activation(config.hidden_act) - else: - self.transform_act_fn = config.hidden_act - - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(inputs=hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Tapas -class TFTapasLMPredictionHead(keras.layers.Layer): - def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - - self.transform = TFTapasPredictionHeadTransform(config, name="transform") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "transform", None) is not None: - with tf.name_scope(self.transform.name): - self.transform.build(None) - - def get_output_embeddings(self) -> keras.layers.Layer: - return self.input_embeddings - - def set_output_embeddings(self, value: tf.Variable): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self) -> dict[str, tf.Variable]: - return {"bias": self.bias} - - def set_bias(self, value: tf.Variable): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.transform(hidden_states=hidden_states) - seq_length = shape_list(hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->Tapas -class TFTapasMLMHead(keras.layers.Layer): - def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs): - super().__init__(**kwargs) - - self.predictions = TFTapasLMPredictionHead(config, input_embeddings, name="predictions") - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - prediction_scores = self.predictions(hidden_states=sequence_output) - - return prediction_scores - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "predictions", None) is not None: - with tf.name_scope(self.predictions.name): - self.predictions.build(None) - - -@keras_serializable -class TFTapasMainLayer(keras.layers.Layer): - config_class = TapasConfig - - def __init__(self, config: TapasConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFTapasEmbeddings(config, name="embeddings") - self.encoder = TFTapasEncoder(config, name="encoder") - self.pooler = TFTapasPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if attention_mask is None: - attention_mask = tf.fill(dims=input_shape, value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape + [len(self.config.type_vocab_sizes)], value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFTapasPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = TapasConfig - base_model_prefix = "tapas" - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"), - "token_type_ids": tf.TensorSpec((None, None, 7), tf.int32, name="token_type_ids"), - } - - -TAPAS_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`TapasConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -TAPAS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0}, 7)`, *optional*): - Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this - class for more info. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. If - `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be - used. Selected in the range `[0, config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.", - TAPAS_START_DOCSTRING, -) -class TFTapasModel(TFTapasPreTrainedModel): - def __init__(self, config: TapasConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.tapas = TFTapasMainLayer(config, name="tapas") - - @unpack_inputs - @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TapasModel - >>> import pandas as pd - - >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base") - >>> model = TapasModel.from_pretrained("google/tapas-base") - - >>> data = { - ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - ... "Age": ["56", "45", "59"], - ... "Number of movies": ["87", "53", "69"], - ... } - >>> table = pd.DataFrame.from_dict(data) - >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] - - >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") - >>> outputs = model(**inputs) - - >>> last_hidden_states = outputs.last_hidden_state - ```""" - outputs = self.tapas( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "tapas", None) is not None: - with tf.name_scope(self.tapas.name): - self.tapas.build(None) - - -@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING) -class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss): - def __init__(self, config: TapasConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if config.is_decoder: - logger.warning( - "If you want to use `TFTapasForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.tapas = TFTapasMainLayer(config, add_pooling_layer=False, name="tapas") - self.lm_head = TFTapasMLMHead(config, input_embeddings=self.tapas.embeddings, name="cls") - - def get_lm_head(self) -> keras.layers.Layer: - return self.lm_head.predictions - - @unpack_inputs - @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TapasForMaskedLM - >>> import pandas as pd - - >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base") - >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base") - - >>> data = { - ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - ... "Age": ["56", "45", "59"], - ... "Number of movies": ["87", "53", "69"], - ... } - >>> table = pd.DataFrame.from_dict(data) - - >>> inputs = tokenizer( - ... table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="tf" - ... ) - >>> labels = tokenizer( - ... table=table, queries="How many movies has George Clooney played in?", return_tensors="tf" - ... )["input_ids"] - - >>> outputs = model(**inputs, labels=labels) - >>> logits = outputs.logits - ```""" - outputs = self.tapas( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "tapas", None) is not None: - with tf.name_scope(self.tapas.name): - self.tapas.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -class TFTapasComputeTokenLogits(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - self.temperature = config.temperature - # cell selection heads - with tf.name_scope("output"): - self.output_weights = self.add_weight( - name="output_weights", - shape=(config.hidden_size,), - dtype=tf.float32, - trainable=True, - initializer=tf.zeros_initializer() - if config.init_cell_selection_weights_to_zero - else keras.initializers.TruncatedNormal(stddev=config.initializer_range), - ) - self.output_bias = self.add_weight( - name="output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer() - ) - - def call(self, sequence_output: tf.Tensor) -> tf.Tensor: - """ - Computes logits per token - - Args: - sequence_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the - model. - - Returns: - logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): Logits per token. - """ - logits = (tf.einsum("bsj,j->bs", sequence_output, self.output_weights) + self.output_bias) / self.temperature - return logits - - -class TFTapasComputeColumnLogits(keras.layers.Layer): - def __init__(self, config: TapasConfig, **kwargs): - super().__init__(**kwargs) - - with tf.name_scope("column_output"): - self.column_output_weights = self.add_weight( - name="column_output_weights", - shape=[config.hidden_size], - dtype=tf.float32, - trainable=True, - initializer=tf.zeros_initializer() - if config.init_cell_selection_weights_to_zero - else keras.initializers.TruncatedNormal(stddev=config.initializer_range), - ) - self.column_output_bias = self.add_weight( - name="column_output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer() - ) - - def call(self, sequence_output, cell_index, cell_mask, allow_empty_column_selection) -> tf.Tensor: - """ - Computes the column logits. - - Args: - sequence_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the - model. - cell_index (`ProductIndexMap`): - Index that groups tokens into cells. - cell_mask (`tf.Tensor` of shape `(batch_size, max_num_rows * max_num_cols)`): - Mask for cells that exist in the table (i.e. that are not padding). - allow_empty_column_selection (`bool`): - Whether to allow not to select any column - - Returns: - column_logits (`tf.Tensor`of shape `(batch_size, max_num_cols)`): Tensor containing the column logits for - every example in the batch. - """ - - # First, compute the token logits (batch_size, seq_len) - without temperature - token_logits = tf.einsum("bsj,j->bs", sequence_output, self.column_output_weights) + self.column_output_bias - - # Next, average the logits per cell (batch_size, max_num_cols*max_num_rows) - cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index) - - # Finally, average the logits per column (batch_size, max_num_cols) - column_index = cell_index.project_inner(cell_logits_index) - column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index) - - cell_count, _ = reduce_sum(cell_mask, column_index) - column_logits /= cell_count + EPSILON_ZERO_DIVISION - - # Mask columns that do not appear in the example. - is_padding = tf.logical_and(cell_count < 0.5, tf.not_equal(out_index.indices, 0)) - column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(is_padding, tf.float32) - - if not allow_empty_column_selection: - column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(tf.equal(out_index.indices, 0), tf.float32) - - return column_logits - - -@add_start_docstrings( - """ - Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables - (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for - SQA, WTQ or WikiSQL-supervised tasks. - """, - TAPAS_START_DOCSTRING, -) -class TFTapasForQuestionAnswering(TFTapasPreTrainedModel): - def __init__(self, config: TapasConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - # base model - self.tapas = TFTapasMainLayer(config, name="tapas") - - # dropout - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - - self.compute_token_logits = TFTapasComputeTokenLogits(config, name="compute_token_logits") - - self.compute_column_logits = TFTapasComputeColumnLogits(config, name="compute_column_logits") - - if config.num_aggregation_labels > 0: - self.aggregation_classifier = keras.layers.Dense( - config.num_aggregation_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="aggregation_classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFTableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - table_mask: np.ndarray | tf.Tensor | None = None, - aggregation_labels: np.ndarray | tf.Tensor | None = None, - float_answer: np.ndarray | tf.Tensor | None = None, - numeric_values: np.ndarray | tf.Tensor | None = None, - numeric_values_scale: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTableQuestionAnsweringOutput | tuple[tf.Tensor]: - r""" - table_mask (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): - Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and - padding are 0. - labels (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): - Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the - answer appearing in the table. Can be obtained using [`AutoTokenizer`]. - - - 1 for tokens that are **part of the answer**, - - 0 for tokens that are **not part of the answer**. - - aggregation_labels (`tf.Tensor` of shape `(batch_size, )`, *optional*): - Aggregation function index for every example in the batch for computing the aggregation loss. Indices - should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for - aggregation (WikiSQL-supervised). - float_answer (`tf.Tensor` of shape `(batch_size, )`, *optional*): - Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only - required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss. - numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): - Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using - [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the - regression loss. - numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): - Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case - of weak supervision for aggregation (WTQ) to calculate the regression loss. - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TapasForQuestionAnswering - >>> import pandas as pd - - >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq") - >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq") - - >>> data = { - ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - ... "Age": ["56", "45", "59"], - ... "Number of movies": ["87", "53", "69"], - ... } - >>> table = pd.DataFrame.from_dict(data) - >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] - - >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") - >>> outputs = model(**inputs) - - >>> logits = outputs.logits - >>> logits_aggregation = outputs.logits_aggregation - ```""" - - outputs = self.tapas( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - - if input_ids is not None: - input_shape = shape_list(input_ids) - else: - input_shape = shape_list(inputs_embeds)[:-1] - - # Construct indices for the table. - if token_type_ids is None: - token_type_ids = tf.fill(input_shape + [len(self.config.type_vocab_sizes)], 0) - - token_types = [ - "segment_ids", - "column_ids", - "row_ids", - "prev_labels", - "column_ranks", - "inv_column_ranks", - "numeric_relations", - ] - - row_ids = token_type_ids[:, :, token_types.index("row_ids")] - column_ids = token_type_ids[:, :, token_types.index("column_ids")] - - # Construct indices for the table. - row_index = IndexMap( - indices=tf.minimum(tf.cast(row_ids, tf.int32), self.config.max_num_rows - 1), - num_segments=self.config.max_num_rows, - batch_dims=1, - ) - col_index = IndexMap( - indices=tf.minimum(tf.cast(column_ids, tf.int32), self.config.max_num_columns - 1), - num_segments=self.config.max_num_columns, - batch_dims=1, - ) - cell_index = ProductIndexMap(row_index, col_index) - - # Masks. - input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:-1] - if attention_mask is None: - attention_mask = tf.ones(input_shape) - # Table cells only, without question tokens and table headers. - if table_mask is None: - table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids), tf.zeros_like(row_ids)) - # [batch_size, seq_length] - input_mask_float = tf.cast(attention_mask, tf.float32) - table_mask_float = tf.cast(table_mask, tf.float32) - - # Mask for cells that exist in the table (i.e. that are not padding). - cell_mask, _ = reduce_mean(input_mask_float, cell_index) - - # Compute logits per token. These are used to select individual cells. - logits = self.compute_token_logits(sequence_output) - - # Compute logits per column. These are used to select a column. - column_logits = None - if self.config.select_one_column: - column_logits = self.compute_column_logits( - sequence_output, cell_index, cell_mask, self.config.allow_empty_column_selection - ) - - # Aggregate logits. - logits_aggregation = None - if self.config.num_aggregation_labels > 0: - logits_aggregation = self.aggregation_classifier(pooled_output) - - # Total loss calculation - total_loss = tf.zeros(shape=(1,), dtype=tf.float32) - calculate_loss = False - if labels is not None: - calculate_loss = True - is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision - - # Semi-supervised cell selection in case of no aggregation: - # If the answer (the denotation) appears directly in the table we might - # select the answer without applying any aggregation function. There are - # some ambiguous cases, see utils._calculate_aggregate_mask for more info. - # `aggregate_mask` is 1 for examples where we chose to aggregate and 0 - # for examples where we chose to select the answer directly. - # `labels` encodes the positions of the answer appearing in the table. - if is_supervised: - aggregate_mask = None - else: - if float_answer is not None: - assert shape_list(labels)[0] == shape_list(float_answer)[0], ( - "Make sure the answers are a FloatTensor of shape (batch_size,)" - ) - # [batch_size] - aggregate_mask = _calculate_aggregate_mask( - float_answer, - pooled_output, - self.config.cell_selection_preference, - labels, - self.aggregation_classifier, - ) - else: - aggregate_mask = None - raise ValueError("You have to specify float answers in order to calculate the aggregate mask") - - # Cell selection log-likelihood - if self.config.average_logits_per_cell: - logits_per_cell, _ = reduce_mean(logits, cell_index) - logits = gather(logits_per_cell, cell_index) - dist_per_token = tfp.distributions.Bernoulli(logits=logits) - - # Compute cell selection loss per example. - selection_loss_per_example = None - if not self.config.select_one_column: - weight = tf.where( - labels == 0, - tf.ones_like(labels, dtype=tf.float32), - self.config.positive_label_weight * tf.ones_like(labels, dtype=tf.float32), - ) - selection_loss_per_token = -dist_per_token.log_prob(labels) * weight - selection_loss_per_example = tf.reduce_sum(selection_loss_per_token * input_mask_float, axis=1) / ( - tf.reduce_sum(input_mask_float, axis=1) + EPSILON_ZERO_DIVISION - ) - else: - selection_loss_per_example, logits = _single_column_cell_selection_loss( - logits, column_logits, labels, cell_index, col_index, cell_mask - ) - dist_per_token = tfp.distributions.Bernoulli(logits=logits) - - # Supervised cell selection - if self.config.disable_per_token_loss: - pass - elif is_supervised: - total_loss += tf.reduce_mean(selection_loss_per_example) - else: - # For the not supervised case, do not assign loss for cell selection - total_loss += tf.reduce_mean(selection_loss_per_example * (1.0 - aggregate_mask)) - - # Semi-supervised regression loss and supervised loss for aggregations - if self.config.num_aggregation_labels > 0: - if is_supervised: - # Note that `aggregate_mask` is None if the setting is supervised. - if aggregation_labels is not None: - assert shape_list(labels)[0] == shape_list(aggregation_labels)[0], ( - "Make sure the aggregation labels are a LongTensor of shape (batch_size,)" - ) - per_example_additional_loss = _calculate_aggregation_loss( - logits_aggregation, - aggregate_mask, - aggregation_labels, - self.config.use_answer_as_supervision, - self.config.num_aggregation_labels, - self.config.aggregation_loss_weight, - ) - else: - raise ValueError( - "You have to specify aggregation labels in order to calculate the aggregation loss" - ) - else: - aggregation_labels = tf.zeros(shape_list(labels)[0], dtype=tf.int32) - per_example_additional_loss = _calculate_aggregation_loss( - logits_aggregation, - aggregate_mask, - aggregation_labels, - self.config.use_answer_as_supervision, - self.config.num_aggregation_labels, - self.config.aggregation_loss_weight, - ) - - if self.config.use_answer_as_supervision: - if numeric_values is not None and numeric_values_scale is not None: - assert shape_list(numeric_values) == shape_list(numeric_values_scale) - # Add regression loss for numeric answers which require aggregation. - answer_loss, large_answer_loss_mask = _calculate_regression_loss( - float_answer, - aggregate_mask, - dist_per_token, - numeric_values, - numeric_values_scale, - table_mask_float, - logits_aggregation, - self.config, - ) - per_example_additional_loss += answer_loss - # Zero loss for examples with answer_loss > cutoff. - per_example_additional_loss *= large_answer_loss_mask - else: - raise ValueError( - "You have to specify numeric values and numeric values scale in order to calculate the" - " regression loss" - ) - total_loss += tf.reduce_mean(per_example_additional_loss) - - else: - # if no label ids are provided, set them to zeros in order to properly compute logits - labels = tf.zeros_like(logits) - _, logits = _single_column_cell_selection_loss( - logits, column_logits, labels, cell_index, col_index, cell_mask - ) - if not return_dict: - output = (logits, logits_aggregation) + outputs[2:] - return ((total_loss,) + output) if calculate_loss else output - - return TFTableQuestionAnsweringOutput( - loss=total_loss if calculate_loss else None, - logits=logits, - logits_aggregation=logits_aggregation, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "tapas", None) is not None: - with tf.name_scope(self.tapas.name): - self.tapas.build(None) - if getattr(self, "compute_token_logits", None) is not None: - with tf.name_scope(self.compute_token_logits.name): - self.compute_token_logits.build(None) - if getattr(self, "compute_column_logits", None) is not None: - with tf.name_scope(self.compute_column_logits.name): - self.compute_column_logits.build(None) - if getattr(self, "aggregation_classifier", None) is not None: - with tf.name_scope(self.aggregation_classifier.name): - self.aggregation_classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table - entailment tasks, such as TabFact (Chen et al., 2020). - """, - TAPAS_START_DOCSTRING, -) -class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: TapasConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.tapas = TFTapasMainLayer(config, name="tapas") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called - "classification_class_index" in the original implementation. - - Returns: - - Examples: - - ```python - >>> from transformers import AutoTokenizer, TapasForSequenceClassification - >>> import tensorflow as tf - >>> import pandas as pd - - >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact") - >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact") - - >>> data = { - ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - ... "Age": ["56", "45", "59"], - ... "Number of movies": ["87", "53", "69"], - ... } - >>> table = pd.DataFrame.from_dict(data) - >>> queries = [ - ... "There is only one actor who is 45 years old", - ... "There are 3 actors which played in more than 60 movies", - ... ] - - >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") - >>> labels = tf.convert_to_tensor([1, 0]) # 1 means entailed, 0 means refuted - - >>> outputs = model(**inputs, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - ```""" - - outputs = self.tapas( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(inputs=pooled_output, training=training) - logits = self.classifier(inputs=pooled_output) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "tapas", None) is not None: - with tf.name_scope(self.tapas.name): - self.tapas.build(None) - if getattr(self, "dropout", None) is not None: - with tf.name_scope(self.dropout.name): - self.dropout.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -""" TAPAS utilities.""" - - -class AverageApproximationFunction(str, enum.Enum): - RATIO = "ratio" - FIRST_ORDER = "first_order" - SECOND_ORDER = "second_order" - - -# Beginning of everything related to segmented tensors - - -class IndexMap: - """Index grouping entries within a tensor.""" - - def __init__(self, indices, num_segments, batch_dims=0): - """ - Creates an index. - - Args: - indices: Tensor of indices, same shape as `values`. - num_segments: Scalar tensor, the number of segments. All elements - in a batched segmented tensor must have the same number of segments (although many segments can be empty). - batch_dims: Python integer, the number of batch dimensions. The first - `batch_dims` dimensions of a SegmentedTensor are treated as batch dimensions. Segments in different batch - elements are always distinct even if they have the same index. - """ - self.indices = tf.convert_to_tensor(indices) - self.num_segments = tf.convert_to_tensor(num_segments) - self.batch_dims = batch_dims - - def batch_shape(self): - return tf.shape(self.indices)[: self.batch_dims] - - -class ProductIndexMap(IndexMap): - """The product of two indices.""" - - def __init__(self, outer_index, inner_index): - """ - Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the - intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows - and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation - combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has `num_segments` equal to - `outer_index.num_segements` * `inner_index.num_segments`. - - Args: - outer_index: IndexMap. - inner_index: IndexMap, must have the same shape as `outer_index`. - """ - if outer_index.batch_dims != inner_index.batch_dims: - raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.") - - super().__init__( - indices=( - inner_index.indices - + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype) - ), - num_segments=inner_index.num_segments * outer_index.num_segments, - batch_dims=inner_index.batch_dims, - ) - self.outer_index = outer_index - self.inner_index = inner_index - - def project_outer(self, index): - """Projects an index with the same index set onto the outer components.""" - return IndexMap( - indices=tf.math.floordiv(index.indices, self.inner_index.num_segments), - num_segments=self.outer_index.num_segments, - batch_dims=index.batch_dims, - ) - - def project_inner(self, index): - """Projects an index with the same index set onto the inner components.""" - return IndexMap( - indices=tf.math.floormod(index.indices, self.inner_index.num_segments), - num_segments=self.inner_index.num_segments, - batch_dims=index.batch_dims, - ) - - -def gather(values, index, name="segmented_gather"): - """ - Gathers from `values` using the index map. For each element in the domain of the index map this operation looks up - a value for that index in `values`. Two elements from the same segment always get assigned the same value. - - Args: - values: [B1, ..., Bn, num_segments, V1, ...] Tensor with segment values. - index: [B1, ..., Bn, I1, ..., Ik] IndexMap. - name: Name for the TensorFlow operation. - - Returns: - [B1, ..., Bn, I1, ..., Ik, V1, ...] Tensor with the gathered values. - """ - return tf.gather(values, index.indices, batch_dims=index.batch_dims, name=name) - - -def flatten(index, name="segmented_flatten"): - """ - Flattens a batched index map to a 1d index map. This operation relabels the segments to keep batch elements - distinct. The k-th batch element will have indices shifted by `num_segments` * (k - 1). The result is a tensor with - `num_segments` multiplied by the number of elements in the batch. - - Args: - index: IndexMap to flatten. - name: Name for the TensorFlow operation. - - Returns: - The flattened IndexMap. - """ - batch_size = tf.reduce_prod(index.batch_shape()) - offset = tf.range(batch_size) * index.num_segments - offset = tf.reshape(offset, index.batch_shape()) - for _ in range(index.batch_dims, index.indices.shape.rank): - offset = tf.expand_dims(offset, -1) - - indices = tf.cast(offset, index.indices.dtype) + index.indices - return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0) - - -def range_index_map(batch_shape, num_segments, name="range_index_map"): - """ - Constructs an index map equal to range(num_segments). - - Args: - batch_shape (`tf.Tensor`): - Batch shape - num_segments (`int`): - Number of segments - name (`str`, *optional*, defaults to 'range_index_map'): - Name for the operation. Currently not used - - Returns: - (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments). - """ - batch_shape = tf.convert_to_tensor(batch_shape) - batch_shape.shape.assert_has_rank(1) - num_segments = tf.convert_to_tensor(num_segments) - num_segments.shape.assert_has_rank(0) - - indices = tf.range(num_segments) - shape = tf.concat([tf.ones_like(batch_shape, dtype=tf.int32), tf.expand_dims(num_segments, axis=0)], axis=0) - indices = tf.reshape(indices, shape) - multiples = tf.concat([batch_shape, [1]], axis=0) - indices = tf.tile(indices, multiples) - return IndexMap(indices=indices, num_segments=num_segments, batch_dims=batch_shape.shape.as_list()[0]) - - -def _segment_reduce(values, index, segment_reduce_fn, name): - """ - Applies a segment reduction segment-wise. - - Args: - values (`tf.Tensor`): - Tensor with segment values. - index (`IndexMap`): - IndexMap. - segment_reduce_fn (`str`): - Name for the reduce operation. One of "sum", "mean", "max" or "min". - name (`str`): - Name for the operation. Currently not used - - Returns: - (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments). - """ - # Flatten the batch dimensions, as segments ops do not support batching. - # However if `values` has extra dimensions to the right keep them - # unflattened. Segmented ops support vector-valued operations. - flat_index = flatten(index) - vector_shape = tf.shape(values)[index.indices.shape.rank :] - flattened_shape = tf.concat([[-1], vector_shape], axis=0) - flat_values = tf.reshape(values, flattened_shape) - segment_means = segment_reduce_fn( - data=flat_values, segment_ids=flat_index.indices, num_segments=flat_index.num_segments - ) - - # Unflatten the values. - new_shape = tf.concat([index.batch_shape(), [index.num_segments], vector_shape], axis=0) - output_values = tf.reshape(segment_means, new_shape) - output_index = range_index_map(index.batch_shape(), index.num_segments) - return output_values, output_index - - -def reduce_mean(values, index, name="segmented_reduce_mean"): - """ - Averages a tensor over its segments. Outputs 0 for empty segments. This operations computes the mean over segments, - with support for: - - - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices. - - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a mean of vectors - rather than scalars. - Only the middle dimensions [I1, ..., Ik] are reduced by the operation. - - Args: - values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be - averaged. - index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments. - name: Name for the TensorFlow ops. - - Returns: - A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn, num_segments, - V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments]. - """ - return _segment_reduce(values, index, tf.math.unsorted_segment_mean, name) - - -def reduce_sum(values, index, name="segmented_reduce_sum"): - """ - Sums a tensor over its segments. Outputs 0 for empty segments. This operations computes the sum over segments, with - support for: - - - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices. - - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a sum of vectors - rather than scalars. - Only the middle dimensions [I1, ..., Ik] are reduced by the operation. - - Args: - values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be - averaged. - index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments. - name: Name for the TensorFlow ops. - - Returns: - A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn, num_segments, - V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments]. - """ - return _segment_reduce(values, index, tf.math.unsorted_segment_sum, name) - - -def reduce_max(values, index, name="segmented_reduce_max"): - """ - Computes the maximum over segments. This operations computes the maximum over segments, with support for: - - - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices. - - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be an element-wise - maximum of vectors rather than scalars. - Only the middle dimensions [I1, ..., Ik] are reduced by the operation. - - Args: - values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be - averaged. - index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments. - name: Name for the TensorFlow ops. - - Returns: - A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn, num_segments, - V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments]. - """ - return _segment_reduce(values, index, tf.math.unsorted_segment_max, name) - - -def reduce_min(values, index, name="segmented_reduce_min"): - """Computes the minimum over segments.""" - return _segment_reduce(values, index, tf.math.unsorted_segment_min, name) - - -def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask): - """ - Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The - model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside - the selected column are never selected. - - Args: - token_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor containing the logits per token. - column_logits (`tf.Tensor` of shape `(batch_size, max_num_cols)`): - Tensor containing the logits per column. - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Labels per token. - cell_index (`ProductIndexMap`): - Index that groups tokens into cells. - col_index (`IndexMap`): - Index that groups tokens into columns. - cell_mask (`tf.Tensor` of shape `(batch_size, max_num_rows * max_num_cols)`): - Mask for cells that exist in the table (i.e. that are not padding). - - Returns: - selection_loss_per_example (`tf.Tensor` of shape `(batch_size,)`): Loss for each example. logits (`tf.Tensor` - of shape `(batch_size, sequence_length)`): New logits which are only allowed to select cells in a single - column. Logits outside of the most likely column according to *column_logits* will be set to a very low value - (such that the probabilities are 0). - """ - # First find the column we should select. We use the column with maximum - # number of selected cells. - labels_per_column, _ = reduce_sum(tf.cast(labels, tf.float32), col_index) - column_label = tf.argmax(labels_per_column, axis=-1, output_type=tf.int32) - # Check if there are no selected cells in the column. In that case the model - # should predict the special column id 0, which means "select nothing". - no_cell_selected = tf.equal(tf.reduce_max(labels_per_column, axis=-1), 0) - column_label = tf.where(no_cell_selected, tf.zeros_like(column_label), column_label) - - column_dist = tfp.distributions.Categorical(logits=column_logits) - column_loss_per_example = -column_dist.log_prob(column_label) - - # Reduce the labels and logits to per-cell from per-token. - logits_per_cell, _ = reduce_mean(token_logits, cell_index) - labels_per_cell, labels_index = reduce_max(tf.cast(labels, tf.int32), cell_index) - - # Mask for the selected column. - column_id_for_cells = cell_index.project_inner(labels_index).indices - column_mask = tf.cast(tf.equal(column_id_for_cells, tf.expand_dims(column_label, axis=1)), tf.float32) - - # Compute the log-likelihood for cells, but only for the selected column. - cell_dist = tfp.distributions.Bernoulli(logits=logits_per_cell) - cell_log_prob = cell_dist.log_prob(labels_per_cell) - cell_loss = -tf.reduce_sum(cell_log_prob * column_mask * cell_mask, axis=1) - # We need to normalize the loss by the number of cells in the column. - cell_loss /= tf.reduce_sum(column_mask * cell_mask, axis=1) + EPSILON_ZERO_DIVISION - - selection_loss_per_example = column_loss_per_example - selection_loss_per_example += tf.where(no_cell_selected, tf.zeros_like(selection_loss_per_example), cell_loss) - - # Set the probs outside the selected column (selected by the *model*) - # to 0. This ensures backwards compatibility with models that select - # cells from multiple columns. - selected_column_id = tf.argmax(column_logits, axis=-1, output_type=tf.int32) - selected_column_mask = tf.cast( - tf.equal(column_id_for_cells, tf.expand_dims(selected_column_id, axis=-1)), tf.float32 - ) - # Never select cells with the special column id 0. - selected_column_mask = tf.where( - tf.equal(column_id_for_cells, 0), tf.zeros_like(selected_column_mask), selected_column_mask - ) - logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask) - logits = gather(logits_per_cell, cell_index) - - return selection_loss_per_example, logits - - -def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier): - """ - Finds examples where the model should select cells with no aggregation. - - Returns a mask that determines for which examples should the model select answers directly from the table, without - any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only - apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation - case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the - aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold - for this is a hyperparameter *cell_selection_preference* - - Args: - answer (`tf.Tensor` of shape `(batch_size, )`): - Answer for every example in the batch. Nan if there is no scalar answer. - pooled_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): - Output of the pooler (BertPooler) on top of the encoder layer. - cell_selection_preference (`float`): - Preference for cell selection in ambiguous cases. - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Labels per token. aggregation_classifier (`torch.nn.Linear`): Aggregation head - - Returns: - aggregate_mask (`tf.Tensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use aggregation - functions. - """ - # tf.Tensor(batch_size,) - aggregate_mask_init = tf.cast(tf.logical_not(tf.math.is_nan(answer)), tf.float32) - logits_aggregation = aggregation_classifier(pooled_output) - dist_aggregation = tfp.distributions.Categorical(logits=logits_aggregation) - # Index 0 corresponds to "no aggregation". - aggregation_ops_total_mass = tf.reduce_sum(dist_aggregation.probs_parameter()[:, 1:], axis=1) - # Cell selection examples according to current model. - is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference - # Examples with non-empty cell selection supervision. - is_cell_supervision_available = tf.reduce_sum(labels, axis=1) > 0 - aggregate_mask = tf.where( - tf.logical_and(is_pred_cell_selection, is_cell_supervision_available), - tf.zeros_like(aggregate_mask_init, dtype=tf.float32), - aggregate_mask_init, - ) - aggregate_mask = tf.stop_gradient(aggregate_mask) - return aggregate_mask - - -def _calculate_aggregation_loss_known( - logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels -): - """ - Calculates aggregation loss when its type is known during training. - - In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation" - should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting - where aggregation type is always known, standard cross entropy loss is accumulated for all examples - - Args: - logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`): - Logits per aggregation operation. - aggregate_mask (`tf.Tensor` of shape `(batch_size, )`): - A mask set to 1 for examples that should use aggregation functions. - aggregation_labels (`tf.Tensor` of shape `(batch_size, )`): - Aggregation function id for every example in the batch. - use_answer_as_supervision (`bool`, *optional*): - Whether to use the answer as the only supervision for aggregation examples. - num_aggregation_labels (`int`, *optional*, defaults to 0): - The number of aggregation operators to predict. - - Returns: - aggregation_loss_known (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss (when its type is known during - training) per example. - """ - if use_answer_as_supervision: - # Prepare "no aggregation" targets for cell selection examples. - target_aggregation = tf.zeros_like(aggregate_mask, dtype=tf.int32) - else: - # Use aggregation supervision as the target. - target_aggregation = aggregation_labels - - one_hot_labels = tf.one_hot(target_aggregation, depth=num_aggregation_labels, dtype=tf.float32) - log_probs = tf.nn.log_softmax(logits_aggregation, axis=-1) - - # [batch_size] - per_example_aggregation_intermediate = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) - if use_answer_as_supervision: - # Accumulate loss only for examples requiring cell selection - # (no aggregation). - return per_example_aggregation_intermediate * (1 - aggregate_mask) - else: - return per_example_aggregation_intermediate - - -def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask): - """ - Calculates aggregation loss in the case of answer supervision. - - Args: - logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`): - Logits per aggregation operation. - aggregate_mask (`tf.Tensor` of shape `(batch_size, )`): - A mask set to 1 for examples that should use aggregation functions - - Returns: - aggregation_loss_unknown (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss (in case of answer - supervision) per example. - """ - dist_aggregation = tfp.distributions.Categorical(logits=logits_aggregation) - # Index 0 corresponds to "no aggregation". - aggregation_ops_total_mass = tf.reduce_sum(dist_aggregation.probs_parameter()[:, 1:], axis=1) - # Predict some aggregation in case of an answer that needs aggregation. - # This increases the probability of all aggregation functions, in a way - # similar to MML, but without considering whether the function gives the - # correct answer. - return -tf.math.log(aggregation_ops_total_mass) * aggregate_mask - - -def _calculate_aggregation_loss( - logits_aggregation, - aggregate_mask, - aggregation_labels, - use_answer_as_supervision, - num_aggregation_labels, - aggregation_loss_weight, -): - """ - Calculates the aggregation loss per example. - - Args: - logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`): - Logits per aggregation operation. - aggregate_mask (`tf.Tensor` of shape `(batch_size, )`): - A mask set to 1 for examples that should use aggregation functions. - aggregation_labels (`tf.Tensor` of shape `(batch_size, )`): - Aggregation function id for every example in the batch. - use_answer_as_supervision (`bool`, *optional*): - Whether to use the answer as the only supervision for aggregation examples. - num_aggregation_labels (`int`, *optional*, defaults to 0): - The number of aggregation operators to predict. - aggregation_loss_weight (`float`, *optional*, defaults to 1.0): - Importance weight for the aggregation loss. - - Returns: - aggregation_loss (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss per example. - """ - per_example_aggregation_loss = _calculate_aggregation_loss_known( - logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels - ) - - if use_answer_as_supervision: - # Add aggregation loss for numeric answers that need aggregation. - per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask) - return aggregation_loss_weight * per_example_aggregation_loss - - -def _calculate_expected_result( - dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config -): - """ - Calculates the expected result given cell and aggregation probabilities. - - Args: - dist_per_cell (`tfp.distributions.Bernoulli`): - Cell selection distribution for each cell. - numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`): - Numeric values of every token. Nan for tokens which are not numeric values. - numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`): - Scale of the numeric values of every token. - input_mask_float (`tf.Tensor` of shape `(batch_size, seq_length)`): - Mask for the table, without question tokens and table headers. - logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`): - Logits per aggregation operation. - config ([`TapasConfig`]): - Model configuration class with all the hyperparameters of the model - - Returns: - expected_result (`tf.Tensor` of shape `(batch_size,)`): The expected result per example. - """ - if config.use_gumbel_for_cells: - gumbel_dist = tfp.distributions.RelaxedBernoulli( - # The token logits where already divided by the temperature and used for - # computing cell selection errors so we need to multiply it again here - config.temperature, - logits=dist_per_cell.logits_parameter() * config.temperature, - ) - scaled_probability_per_cell = gumbel_dist.sample() - else: - scaled_probability_per_cell = dist_per_cell.probs_parameter() - - # [batch_size, seq_length] - scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float - count_result = tf.reduce_sum(scaled_probability_per_cell, axis=1) - numeric_values_masked = tf.where( - tf.math.is_nan(numeric_values), tf.zeros_like(numeric_values), numeric_values - ) # Mask non-numeric table values to zero. - sum_result = tf.reduce_sum(scaled_probability_per_cell * numeric_values_masked, axis=1) - avg_approximation = config.average_approximation_function - if avg_approximation == AverageApproximationFunction.RATIO: - average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION) - elif avg_approximation == AverageApproximationFunction.FIRST_ORDER: - # The sum of all probabilities except that correspond to other cells - ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1 - average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell / ex, axis=1) - elif avg_approximation == AverageApproximationFunction.SECOND_ORDER: - # The sum of all probabilities except that correspond to other cells - ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1 - pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell) - var = tf.reduce_sum(pointwise_var, axis=1, keepdims=True) - pointwise_var - multiplier = (var / tf.math.square(ex) + 1) / ex - average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell * multiplier, axis=1) - else: - raise ValueError("Invalid average_approximation_function: %s", config.average_approximation_function) - - if config.use_gumbel_for_aggregation: - gumbel_dist = tfp.distributions.RelaxedOneHotCategorical( - config.aggregation_temperature, logits=logits_aggregation[:, 1:] - ) - # [batch_size, num_aggregation_labels - 1] - aggregation_op_only_probs = gumbel_dist.sample() - else: - # [batch_size, num_aggregation_labels - 1] - aggregation_op_only_probs = stable_softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1) - all_results = tf.concat( - [ - tf.expand_dims(sum_result, axis=1), - tf.expand_dims(average_result, axis=1), - tf.expand_dims(count_result, axis=1), - ], - axis=1, - ) - expected_result = tf.reduce_sum(all_results * aggregation_op_only_probs, axis=1) - return expected_result - - -def _calculate_regression_loss( - answer, - aggregate_mask, - dist_per_cell, - numeric_values, - numeric_values_scale, - input_mask_float, - logits_aggregation, - config, -): - """ - Calculates the regression loss per example. - - Args: - answer (`tf.Tensor` of shape `(batch_size,)`): - Answer for every example in the batch. Nan if there is no scalar answer. - aggregate_mask (`tf.Tensor` of shape `(batch_size,)`): - A mask set to 1 for examples that should use aggregation functions. - dist_per_cell (`torch.distributions.Bernoulli`): - Cell selection distribution for each cell. - numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`): - Numeric values of every token. Nan for tokens which are not numeric values. - numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`): - Scale of the numeric values of every token. - input_mask_float (`tf.Tensor` of shape `(batch_size, seq_length)`): - Mask for the table, without question tokens and table headers. - logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`): - Logits per aggregation operation. - config ([`TapasConfig`]): - Model configuration class with all the parameters of the model - - Returns: - per_example_answer_loss_scaled (`tf.Tensor` of shape `(batch_size,)`): Scales answer loss for each example in - the batch. large_answer_loss_mask (`tf.Tensor` of shape `(batch_size,)`): A mask which is 1 for examples for - which their answer loss is larger than the answer_loss_cutoff. - """ - # float32 (batch_size,) - expected_result = _calculate_expected_result( - dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config - ) - - # [batch_size] - answer_masked = tf.where(tf.math.is_nan(answer), tf.zeros_like(answer), answer) - - if config.use_normalized_answer_loss: - normalizer = tf.stop_gradient( - tf.math.maximum(tf.math.abs(expected_result), tf.math.abs(answer_masked)) + EPSILON_ZERO_DIVISION - ) - normalized_answer_masked = answer_masked / normalizer - normalized_expected_result = expected_result / normalizer - per_example_answer_loss = tf.compat.v1.losses.huber_loss( - normalized_answer_masked * aggregate_mask, - normalized_expected_result * aggregate_mask, - delta=tf.cast(1.0, tf.float32), - reduction=tf.losses.Reduction.NONE, - ) - else: - per_example_answer_loss = tf.compat.v1.losses.huber_loss( - answer_masked * aggregate_mask, - expected_result * aggregate_mask, - delta=tf.cast(config.huber_loss_delta, tf.float32), - reduction=tf.losses.Reduction.NONE, - ) - if config.answer_loss_cutoff is None: - large_answer_loss_mask = tf.ones_like(per_example_answer_loss, dtype=tf.float32) - else: - large_answer_loss_mask = tf.where( - per_example_answer_loss > config.answer_loss_cutoff, - tf.zeros_like(per_example_answer_loss, dtype=tf.float32), - tf.ones_like(per_example_answer_loss, dtype=tf.float32), - ) - per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask) - return per_example_answer_loss_scaled, large_answer_loss_mask - - -__all__ = [ - "TFTapasForMaskedLM", - "TFTapasForQuestionAnswering", - "TFTapasForSequenceClassification", - "TFTapasModel", - "TFTapasPreTrainedModel", -] diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py deleted file mode 100644 index a59c799cc04a..000000000000 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ /dev/null @@ -1,864 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes to support Vision-Encoder-Text-Decoder architectures""" - -import os -from typing import Optional, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput -from ...modeling_flax_utils import FlaxPreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM -from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig" - -VISION_ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model - as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via - [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] - function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream - generative task, like image captioning. - - The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation - tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation - Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi - Zhou, Wei Li, Peter J. Liu. - - Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained - Models](https://huggingface.co/papers/2109.10282) it is shown how leveraging large pretrained vision models for optical - character recognition (OCR) yields a significant performance improvement. - - After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any - other models (see the examples for more information). - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Parameters: - config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using - [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details. - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. -""" - -VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using - [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. -""" - -VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - For sequence to sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is - provided, the model will create this tensor by shifting the `input_ids` to the right for denoising - pre-training. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.decoder.max_position_embeddings - 1]`. - past_key_values (`dict[str, jnp.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a - plain tuple. -""" - - -class FlaxVisionEncoderDecoderModule(nn.Module): - config: VisionEncoderDecoderConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - encoder_config = self.config.encoder - decoder_config = self.config.decoder - - # Copied from `modeling_hybrid_clip.py` with modifications. - from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING - - encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class - decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class - - self.encoder = encoder_module(encoder_config, dtype=self.dtype) - self.decoder = decoder_module(decoder_config, dtype=self.dtype) - - # encoder outputs might need to be projected to different dimension for decoder - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = nn.Dense( - self.decoder.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), - dtype=self.dtype, - ) - else: - self.enc_to_dec_proj = None - - def _get_encoder_module(self): - return self.encoder - - def _get_projection_module(self): - return self.enc_to_dec_proj - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - pixel_values, - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if self.enc_to_dec_proj is not None: - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - # The advantage of explicitly setting this is TPU XLA compiler knows as soon as possible what shape this - # variable has and can better optimize. Also passing `None` can lead to some problems when jitting the model. - # In Flax/JAX, we only want to pass `None` for non-tensor function inputs. For all tensor function inputs, we - # should always pass a tensor and not `None`. - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=decoder_outputs.logits, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING) -class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): - r""" - [`FlaxVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture - with the module (flax.nn.Module) of one of the base vision model classes of the library as encoder module and - another one as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method - for the encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. - """ - - config_class = VisionEncoderDecoderConfig - base_model_prefix = "vision_encoder_decoder" - main_input_name = "pixel_values" - module_class = FlaxVisionEncoderDecoderModule - - def __init__( - self, - config: VisionEncoderDecoderConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if not _do_init: - raise ValueError( - "`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." - ) - - if input_shape is None: - num_channels = getattr(config.encoder, "num_channels", 3) - input_shape = ( - (1, config.encoder.image_size, config.encoder.image_size, num_channels), - (1, 1), - ) - - if config.decoder.cross_attention_hidden_size is not None: - if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" - f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" - " `config.encoder.hidden_size`." - ) - - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - encoder_input_shape, decoder_input_shape = input_shape - - # init input tensors - pixel_values = jnp.zeros(encoder_input_shape, dtype=self.dtype) - decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - batch_size, _, _, _ = pixel_values.shape - decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape - if not decoder_batch_size == batch_size: - raise ValueError( - f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder " - f"and {decoder_batch_size} for decoder." - ) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) - ) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - pixel_values, - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def encode( - self, - pixel_values: jnp.ndarray, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxVisionEncoderDecoderModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") - - >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "openai-community/gpt2" - ... ) - - >>> pixel_values = image_processor(images=image, return_tensors="np").pixel_values - >>> encoder_outputs = model.encode(pixel_values) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format. - # Currently, we assume this holds for all Flax vision models, and perform a transpose here. - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, pixel_values, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(pixel_values, **kwargs) - - outputs = self.module.apply( - {"params": params or self.params}, - pixel_values=jnp.array(pixel_values, dtype=self.dtype), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - if return_dict: - outputs = FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return outputs - - @add_start_docstrings(VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - def decode( - self, - decoder_input_ids, - encoder_outputs, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxVisionEncoderDecoderModel - >>> import jax.numpy as jnp - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") - - >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "openai-community/gpt2" - ... ) - - >>> pixel_values = image_processor(images=image, return_tensors="np").pixel_values - >>> encoder_outputs = model.encode(pixel_values) - - >>> decoder_start_token_id = model.config.decoder.bos_token_id - >>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxBartAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward( - module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs - ): - projection_module = module._get_projection_module() - decoder_module = module._get_decoder_module() - - # optionally project encoder_hidden_states - if projection_module is not None: - encoder_hidden_states = projection_module(encoder_hidden_states) - - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - encoder_hidden_states, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def __call__( - self, - pixel_values: jnp.ndarray, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Examples: - - ```python - >>> from transformers import FlaxVisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") - - >>> # load output tokenizer - >>> tokenizer_output = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "openai-community/gpt2" - ... ) - - >>> pixel_values = image_processor(images=image, return_tensors="np").pixel_values - - >>> # use GPT2's eos_token as the pad as well as eos token - >>> model.config.eos_token_id = model.config.decoder.eos_token_id - >>> model.config.pad_token_id = model.config.eos_token_id - - >>> # generation - >>> sequences = model.generate(pixel_values, num_beams=4, max_length=12).sequences - - >>> captions = tokenizer_output.batch_decode(sequences, skip_special_tokens=True) - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - - # `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format. - # Currently, we assume this holds for all Flax vision models, and perform a transpose here. - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - # prepare decoder inputs - if decoder_input_ids is None: - raise ValueError("`decoder_input_ids` can't be `None`.") - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - pixel_values=jnp.array(pixel_values, dtype=self.dtype), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": decoder_position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - @classmethod - def from_encoder_decoder_pretrained( - cls, - encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - *model_args, - **kwargs, - ) -> FlaxPreTrainedModel: - r""" - Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model - checkpoints. - - Params: - encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): - Information necessary to initiate the encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An - example is `google/vit-base-patch16-224-in21k`. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import FlaxVisionEncoderDecoderModel - - >>> # initialize a vit-gpt2 from a pretrained ViT and a pretrained GPT2 model. Note that the cross-attention layers will be randomly initialized - >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "openai-community/gpt2" - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./vit-gpt2") - >>> # load fine-tuned model - >>> model = FlaxVisionEncoderDecoderModel.from_pretrained("./vit-gpt2") - ```""" - - kwargs_encoder = { - argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove encoder, decoder kwargs from kwargs - for key in kwargs_encoder: - del kwargs["encoder_" + key] - for key in kwargs_decoder: - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("model", None) - if encoder is None: - if encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " - "from a decoder model. Cross-attention and causal mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_encoder["config"] = encoder_config - - encoder = FlaxAutoModel.from_pretrained( - encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder - ) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" - ) - - decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # instantiate config with corresponding kwargs - dtype = kwargs.pop("dtype", jnp.float32) - config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) - - # init model - model = cls(config, dtype=dtype) - model.params["encoder"] = encoder.params - model.params["decoder"] = decoder.params - - return model - - -__all__ = ["FlaxVisionEncoderDecoderModel"] diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py deleted file mode 100644 index ef2ea2109987..000000000000 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ /dev/null @@ -1,696 +0,0 @@ -# coding=utf-8 -# Copyright 2022 HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes to support TF Vision-Encoder-Text-Decoder architectures""" - -from __future__ import annotations - -import re -import warnings - -import numpy as np -import tensorflow as tf - -from ...configuration_utils import PretrainedConfig -from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, keras, unpack_inputs -from ...tf_utils import shape_list -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM -from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig" - -DEPRECATION_WARNING = ( - "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the" - " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" - " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the" - " labels, no need to pass them yourself anymore." -) - -VISION_ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model - as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via - [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`] - function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream - generative task, like image captioning. - - The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation - tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation - Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi - Zhou, Wei Li, Peter J. Liu. - - Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained - Models](https://huggingface.co/papers/2109.10282) it is shown how leveraging large pretrained vision models for optical - character recognition (OCR) yields a significant performance improvement. - - After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any - other models (see the examples for more information). - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using the vision's model's image processor. For example, using - [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details. - decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - Provide for sequence to sequence training to the decoder. Indices can be obtained using - [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for - details. - decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*): - This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output - of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `({0})`. - decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert `decoder_input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. - labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, - ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: - - - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function. -""" - - -# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - if pad_token_id is None: - raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - - if decoder_start_token_id is None: - raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING) -class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): - r""" - [`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture - with one of the base vision model classes of the library as encoder and another one of the base model classes as - decoder when created with the [`~TFAutoModel.from_pretrained`] class method for the encoder and - [`~TFAutoModelForCausalLM.from_pretrained`] class method for the decoder. - """ - - config_class = VisionEncoderDecoderConfig - base_model_prefix = "vision_encoder_decoder" - load_weight_prefix = "tf_vision_encoder_decoder_model" - main_input_name = "pixel_values" - - def __init__( - self, - config: PretrainedConfig | None = None, - encoder: TFPreTrainedModel | None = None, - decoder: TFPreTrainedModel | None = None, - ): - if config is None and (encoder is None or decoder is None): - raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") - if config is None: - config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) - else: - if not isinstance(config, self.config_class): - raise ValueError(f"config: {config} has to be of type {self.config_class}") - - if config.decoder.cross_attention_hidden_size is not None: - if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" - f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" - " `config.encoder.hidden_size`." - ) - - # initialize with config - super().__init__(config) - - if encoder is None: - encoder = TFAutoModel.from_config(config.encoder, name="encoder") - - if decoder is None: - decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder") - - self.encoder = encoder - self.decoder = decoder - - if self.encoder.config.to_dict() != self.config.encoder.to_dict(): - logger.warning( - f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" - f" {self.config.encoder}" - ) - if self.decoder.config.to_dict() != self.config.decoder.to_dict(): - logger.warning( - f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" - f" {self.config.decoder}" - ) - - # make sure that the individual model's config refers to the shared config - # so that the updates to the config will be synced - self.encoder.config = self.config.encoder - self.decoder.config = self.config.decoder - - # encoder outputs might need to be projected to different dimension for decoder - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = keras.layers.Dense( - units=self.decoder.config.hidden_size, - kernel_initializer=get_initializer(config.encoder.initializer_range), - name="enc_to_dec_proj", - ) - - if self.encoder.get_output_embeddings() is not None: - raise ValueError( - f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" - ) - - @property - def input_signature(self): - vision_config = self.config.encoder - if hasattr(vision_config, "vision_config"): - vision_config = vision_config.vision_config - if hasattr(vision_config, "image_size"): - image_size = vision_config.image_size - else: - image_size = vision_config.input_size - return { - "pixel_values": tf.TensorSpec( - shape=( - None, - vision_config.num_channels, - image_size, - image_size, - ), - dtype=tf.float32, - ), - "decoder_input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="decoder_input_ids"), - } - - def get_encoder(self): - return self.encoder - - def get_input_embeddings(self): - return self.encoder.get_input_embeddings() - - def get_output_embeddings(self): - return self.decoder.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - return self.decoder.set_output_embeddings(new_embeddings) - - def tf_to_pt_weight_rename(self, tf_weight): - # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models - # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. - # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption - # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's - # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - - # This override is only needed in the case where we're crossloading weights from PT. However, since weights are - # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. - # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it - # or not. - encoder_model_type = self.config.encoder.model_type - if "encoder" in tf_weight and "decoder" not in tf_weight: - return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) - else: - return (tf_weight,) - - @classmethod - def from_encoder_decoder_pretrained( - cls, - encoder_pretrained_model_name_or_path: str | None = None, - decoder_pretrained_model_name_or_path: str | None = None, - *model_args, - **kwargs, - ) -> TFPreTrainedModel: - r""" - Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model - checkpoints. - - - Params: - encoder_pretrained_model_name_or_path (`str`, *optional*): - Information necessary to initiate the encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An - example is `google/vit-base-patch16-224-in21k`. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case, - `encoder_from_pt` should be set to `True`. - - decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to *None*): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case, - `decoder_from_pt` should be set to `True`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import TFVisionEncoderDecoderModel - - >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized - >>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased" - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./vit-bert") - >>> # load fine-tuned model - >>> model = TFVisionEncoderDecoderModel.from_pretrained("./vit-bert") - ```""" - - kwargs_encoder = { - argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove encoder, decoder kwargs from kwargs - for key in kwargs_encoder: - del kwargs["encoder_" + key] - for key in kwargs_decoder: - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("model", None) - if encoder is None: - if encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " - "from a decoder model. Cross-attention and causal mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_encoder["config"] = encoder_config - - kwargs_encoder["name"] = "encoder" - kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix - encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" - ) - - kwargs_decoder["name"] = "decoder" - kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix - decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly. - if encoder.name != "encoder": - raise ValueError("encoder model must be created with the name `encoder`.") - if decoder.name != "decoder": - raise ValueError("decoder model must be created with the name `decoder`.") - - # instantiate config with corresponding kwargs - config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) - return cls(encoder=encoder, decoder=decoder, config=config) - - @unpack_inputs - @add_start_docstrings_to_model_forward( - VISION_ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length") - ) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: np.ndarray | tf.Tensor | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple | TFBaseModelOutput | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs, - ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, AutoTokenizer, TFVisionEncoderDecoderModel - >>> from PIL import Image - >>> import requests - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") - >>> decoder_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized - >>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - ... "google/vit-base-patch16-224-in21k", "openai-community/gpt2" - ... ) - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> img = Image.open(requests.get(url, stream=True).raw) - - >>> # forward - >>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values # Batch size 1 - >>> decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids # Batch size 1 - >>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) - - >>> # training - >>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids) - >>> loss, logits = outputs.loss, outputs.logits - - >>> # save and load from pretrained - >>> model.save_pretrained("vit-gpt2") - >>> model = TFVisionEncoderDecoderModel.from_pretrained("vit-gpt2") - - >>> # generation - >>> generated = model.generate(pixel_values, decoder_start_token_id=model.config.decoder.bos_token_id) - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # Let the user be responsible for the expected format. - if encoder_outputs is not None: - if return_dict and not isinstance(encoder_outputs, ModelOutput): - raise ValueError( - "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of " - f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`." - ) - - if encoder_outputs is None: - encoder_inputs = { - "input_ids": pixel_values, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "return_dict": return_dict, - "training": training, - } - - # Add arguments to encoder from `kwargs_encoder` - encoder_inputs.update(kwargs_encoder) - - if "input_ids" in encoder_inputs: - encoder_inputs["pixel_values"] = encoder_inputs.pop("input_ids") - - if encoder_inputs["pixel_values"] is None: - raise ValueError("You have to specify pixel_values") - - # Handle the case where the inputs are passed as a single dict which contains `labels`. - # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this - # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`). - if "labels" in encoder_inputs: - labels = encoder_inputs.pop("labels") - - # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. - if "decoder_input_ids" in encoder_inputs: - decoder_input_ids = encoder_inputs.pop("decoder_input_ids") - # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. - if "decoder_attention_mask" in encoder_inputs: - decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask") - - encoder_outputs = self.encoder(**encoder_inputs) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if ( - self.encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - batch_size, sequence_length = shape_list(encoder_hidden_states)[:2] - encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32) - - decoder_inputs = { - "input_ids": decoder_input_ids, - "attention_mask": decoder_attention_mask, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": encoder_attention_mask, - "inputs_embeds": decoder_inputs_embeds, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "use_cache": use_cache, - "past_key_values": past_key_values, - "return_dict": return_dict, - "training": training, - } - - # Add arguments to decoder from `kwargs_decoder` - decoder_inputs.update(kwargs_decoder) - - decoder_outputs = self.decoder(**decoder_inputs) - - logits = decoder_outputs[0] - - # Compute loss independent from decoder (as some shift the logits inside them) - loss = None - if labels is not None: - warnings.warn(DEPRECATION_WARNING, FutureWarning) - loss = self.hf_compute_loss(labels, logits) - - if not return_dict: - past_key_values = None - if use_cache: - past_key_values = decoder_outputs[1] - # The starting index of the remaining elements in `decoder_outputs` - start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - - if not isinstance(encoder_outputs, tuple): - encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs - output = tuple(x for x in output if x is not None) - return output - - return TFSeq2SeqLMOutput( - loss=loss, - logits=decoder_outputs.logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None - dec_hs = ( - tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None - ) - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None - enc_hs = ( - tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None - ) - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None - cross_attns = ( - tf.convert_to_tensor(output.cross_attentions) - if self.config.decoder.output_attentions and output.cross_attentions is not None - else None - ) - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - cross_attentions=cross_attns, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs - ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) - decoder_attention_mask = decoder_inputs.get("attention_mask", None) - past_key_values = decoder_inputs.get("past_key_values") - input_dict = { - "pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "decoder_input_ids": decoder_inputs["input_ids"], - # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete - "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), - "past_key_values": past_key_values, - "use_cache": use_cache, - } - return input_dict - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - - def resize_token_embeddings(self, *args, **kwargs): - raise NotImplementedError( - "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported. " - "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))" - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "enc_to_dec_proj", None) is not None: - with tf.name_scope(self.enc_to_dec_proj.name): - self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size]) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -__all__ = ["TFVisionEncoderDecoderModel"] diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py deleted file mode 100644 index 15ef5d48a32d..000000000000 --- a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py +++ /dev/null @@ -1,601 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax VisionTextDualEncoder model.""" - -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.traverse_util import flatten_dict, unflatten_dict - -from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring -from ...utils import add_start_docstrings, logging -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel -from ..clip.modeling_flax_clip import FlaxCLIPOutput, FlaxCLIPVisionModel -from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig" - -VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r""" - This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model - as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded - via the [`~FlaxAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and - should be fine-tuned on a downstream task, like contrastive image-text modeling. - - In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://huggingface.co/papers/2111.07991) it is shown how - leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvement - on new zero-shot vision tasks such as image classification or retrieval. - - After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other - models (see the examples for more information). - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it - as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`VisionTextDualEncoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - an image processor (e.g. if you use ViT as the encoder, you should use [`AutoImageProcessor`]). See - [`ViTImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxVisionTextDualEncoderModule(nn.Module): - config: VisionTextDualEncoderConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - vision_config = self.config.vision_config - text_config = self.config.text_config - - self.vision_embed_dim = vision_config.hidden_size - self.text_embed_dim = text_config.hidden_size - self.projection_dim = self.config.projection_dim - - vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class - text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class - - self.vision_model = vision_module(vision_config, dtype=self.dtype) - self.text_model = text_module(text_config, dtype=self.dtype) - - self.visual_projection = nn.Dense( - self.projection_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02), - use_bias=False, - ) - self.text_projection = nn.Dense( - self.projection_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02), - use_bias=False, - ) - - self.logit_scale = self.param( - "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, [] - ) - - def __call__( - self, - input_ids=None, - pixel_values=None, - attention_mask=None, - position_ids=None, - token_type_ids=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - image_embeds = vision_outputs[1] - image_embeds = self.visual_projection(image_embeds) - - text_embeds = text_outputs[1] - text_embeds = self.text_projection(text_embeds) - - # normalized features - image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True) - text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = jnp.exp(self.logit_scale) - logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale - logits_per_image = logits_per_text.T - - if not return_dict: - return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - - return FlaxCLIPOutput( - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING) -class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel): - config_class = VisionTextDualEncoderConfig - module_class = FlaxVisionTextDualEncoderModule - - def __init__( - self, - config: VisionTextDualEncoderConfig, - input_shape: Optional[tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if not _do_init: - raise ValueError( - "`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`." - ) - - if input_shape is None: - input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) - - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensor - input_ids = jnp.zeros(input_shape[0], dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) - token_type_ids = jnp.ones_like(input_ids) - attention_mask = jnp.ones_like(input_ids) - - pixel_values = jax.random.normal(rng, input_shape[1]) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[ - "params" - ] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def __call__( - self, - input_ids, - pixel_values, - attention_mask=None, - position_ids=None, - token_type_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(pixel_values, dtype=jnp.float32), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - def get_text_features( - self, - input_ids, - attention_mask=None, - position_ids=None, - token_type_ids=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train=False, - ): - r""" - Args: - input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - - Returns: - text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of text model. - """ - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic): - text_outputs = module.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - token_type_ids=token_type_ids, - deterministic=deterministic, - ) - pooled_output = text_outputs[1] - text_features = module.text_projection(pooled_output) - return text_features - - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - not train, - method=_get_features, - rngs=rngs, - ) - - def get_image_features( - self, pixel_values, params: Optional[dict] = None, dropout_rng: jax.random.PRNGKey = None, train=False - ): - r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained - using [`ImageFeatureExtractionMixin`]. See [`ImageFeatureExtractionMixin.__call__`] for details. - - Returns: - image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of vision model. - """ - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _get_features(module, pixel_values, deterministic): - vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic) - pooled_output = vision_outputs[1] # pooled_output - image_features = module.visual_projection(pooled_output) - return image_features - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - method=_get_features, - rngs=rngs, - ) - - @classmethod - def from_vision_text_pretrained( - cls, - vision_model_name_or_path: Optional[str] = None, - text_model_name_or_path: Optional[str] = None, - *model_args, - **kwargs, - ) -> FlaxPreTrainedModel: - """ - Params: - vision_model_name_or_path (`str`, *optional*, defaults to `None`): - Information necessary to initiate the vision model. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. This - loading path is slower than converting the PyTorch checkpoint in a Flax model using the provided - conversion scripts and loading the Flax model afterwards. - - text_model_name_or_path (`str`, *optional*): - Information necessary to initiate the text model. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. This - loading path is slower than converting the PyTorch checkpoint in a Flax model using the provided - conversion scripts and loading the Flax model afterwards. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the text configuration, use the prefix *text_* for each configuration parameter. - - To update the vision configuration, use the prefix *vision_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import FlaxVisionTextDualEncoderModel - - >>> # initialize a model from pretrained ViT and BERT models. Note that the projection layers will be randomly initialized. - >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained( - ... "google/vit-base-patch16-224", "google-bert/bert-base-uncased" - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./vit-bert") - >>> # load fine-tuned model - >>> model = FlaxVisionTextDualEncoderModel.from_pretrained("./vit-bert") - ```""" - - kwargs_vision = { - argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") - } - - kwargs_text = { - argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") - } - - # remove text, vision kwargs from kwargs - for key in kwargs_vision: - del kwargs["vision_" + key] - for key in kwargs_text: - del kwargs["text_" + key] - - # Load and initialize the text and vision model - vision_model = kwargs_vision.pop("model", None) - if vision_model is None: - if vision_model_name_or_path is None: - raise ValueError( - "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" - ) - - if "config" not in kwargs_vision: - vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) - - if vision_config.model_type == "clip": - kwargs_vision["config"] = vision_config.vision_config - vision_model = FlaxCLIPVisionModel.from_pretrained( - vision_model_name_or_path, *model_args, **kwargs_vision - ) - else: - kwargs_vision["config"] = vision_config - vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) - - text_model = kwargs_text.pop("model", None) - if text_model is None: - if text_model_name_or_path is None: - raise ValueError( - "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" - ) - - if "config" not in kwargs_text: - text_config = AutoConfig.from_pretrained(text_model_name_or_path) - kwargs_text["config"] = text_config - - text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) - - # instantiate config with corresponding kwargs - dtype = kwargs.pop("dtype", jnp.float32) - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs) - - # init model - model = cls(config, *model_args, dtype=dtype, **kwargs) - - model.params["vision_model"] = vision_model.params - model.params["text_model"] = text_model.params - - # the projection layers are always newly initialized when loading the model - # using pre-trained vision and text model. - logger.warning( - "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection'," - " 'kernel'), ('logit_scale',)]` are newly initialized. You should probably TRAIN this model on a" - " down-stream task to be able to use it for predictions and inference." - ) - - return model - - -VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING = r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> import jax - >>> from transformers import ( - ... FlaxVisionTextDualEncoderModel, - ... VisionTextDualEncoderProcessor, - ... AutoImageProcessor, - ... AutoTokenizer, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") - >>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer) - >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained( - ... "google/vit-base-patch16-224", "google-bert/bert-base-uncased" - ... ) - - >>> # contrastive training - >>> urls = [ - ... "http://images.cocodataset.org/val2017/000000039769.jpg", - ... "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg", - ... ] - >>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls] - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True - ... ) - >>> outputs = model( - ... input_ids=inputs.input_ids, - ... attention_mask=inputs.attention_mask, - ... pixel_values=inputs.pixel_values, - ... ) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - - >>> # save and load from pretrained - >>> model.save_pretrained("vit-bert") - >>> model = FlaxVisionTextDualEncoderModel.from_pretrained("vit-bert") - - >>> # inference - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ``` -""" - -overwrite_call_docstring( - FlaxVisionTextDualEncoderModel, - VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING + VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING, -) -append_replace_return_docstrings( - FlaxVisionTextDualEncoderModel, output_type=FlaxCLIPOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = ["FlaxVisionTextDualEncoderModel"] diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py deleted file mode 100644 index 42ff0be7a9e8..000000000000 --- a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py +++ /dev/null @@ -1,623 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow VisionTextDualEncoder model.""" - -from __future__ import annotations - -import re - -import tensorflow as tf - -from ...configuration_utils import PretrainedConfig -from ...modeling_tf_utils import TFPreTrainedModel, keras, unpack_inputs -from ...tf_utils import shape_list -from ...utils import ( - DUMMY_INPUTS, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from ..auto.configuration_auto import AutoConfig -from ..auto.modeling_tf_auto import TFAutoModel -from ..clip.modeling_tf_clip import CLIPVisionConfig, TFCLIPOutput, TFCLIPVisionModel -from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig" - -VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r""" - This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model - as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded - via the [`~TFAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and - should be fine-tuned on a downstream task, like contrastive image-text modeling. - - In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://huggingface.co/papers/2111.07991) it is shown how - leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvement - on new zero-shot vision tasks such as image classification or retrieval. - - After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other - models (see the examples for more information). - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a - regular Keras Model and refer to the TF documentation for all matter related to general usage and behavior. - - Parameters: - config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - an image processor (e.g. if you use ViT as the encoder, you should use [`AutoImageProcessor`]). See - [`ViTImageProcessor.__call__`] for details. - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.clip.modeling_tf_clip.contrastive_loss -def contrastive_loss(logits: tf.Tensor) -> tf.Tensor: - return tf.math.reduce_mean( - keras.metrics.sparse_categorical_crossentropy( - y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True - ) - ) - - -# Copied from transformers.models.clip.modeling_tf_clip.clip_loss -def clip_loss(similarity: tf.Tensor) -> tf.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(tf.transpose(similarity)) - return (caption_loss + image_loss) / 2.0 - - -@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING) -class TFVisionTextDualEncoderModel(TFPreTrainedModel): - config_class = VisionTextDualEncoderConfig - base_model_prefix = "vision_text_dual_encoder" - load_weight_prefix = "tf_vision_text_dual_encoder_model" - - def __init__( - self, - config: VisionTextDualEncoderConfig | None = None, - vision_model: TFPreTrainedModel | None = None, - text_model: TFPreTrainedModel | None = None, - ): - if config is None and (vision_model is None or text_model is None): - raise ValueError("Either a configuration or an vision and a text model has to be provided") - - if config is None: - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config) - else: - if not isinstance(config, self.config_class): - raise ValueError(f"config: {config} has to be of type {self.config_class}") - - # initialize with config - super().__init__(config) - - if vision_model is None: - if isinstance(config.vision_config, CLIPVisionConfig): - vision_model = TFCLIPVisionModel.from_config(config.vision_config, name="vision_model") - else: - vision_model = TFAutoModel.from_config(config.vision_config, name="vision_model") - - if text_model is None: - text_model = TFAutoModel.from_config(config.text_config, name="text_model") - - self.vision_model = vision_model - self.text_model = text_model - - # make sure that the individual model's config refers to the shared config - # so that the updates to the config will be synced - self.vision_model.config = self.config.vision_config - self.text_model.config = self.config.text_config - - self.vision_embed_dim = config.vision_config.hidden_size - self.text_embed_dim = config.text_config.hidden_size - self.projection_dim = config.projection_dim - - self.visual_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="visual_projection") - self.text_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="text_projection") - self.logit_scale = None - self.config = config - - def build(self, input_shape=None): - if self.built: - return - self.built = True - # Build in the build() method to make sure the names are right - initializer = keras.initializers.Constant(self.config.logit_scale_init_value) - self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale") - - if getattr(self, "visual_projection", None) is not None: - with tf.name_scope(self.visual_projection.name): - self.visual_projection.build([None, None, self.vision_embed_dim]) - if getattr(self, "text_projection", None) is not None: - with tf.name_scope(self.text_projection.name): - self.text_projection.build([None, None, self.text_embed_dim]) - with tf.name_scope(self.vision_model.name): - self.vision_model.build(None) - with tf.name_scope(self.text_model.name): - self.text_model.build(None) - - def tf_to_pt_weight_rename(self, tf_weight): - # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models - # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. - # However, the name of that extra layer is the name of the MainLayer in the base model. - if "vision_model" in tf_weight: - if tf_weight.count("vision_model") == 1: - return (re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight),) - elif tf_weight.count("vision_model") == 2: - return (re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight),) - else: - raise ValueError( - f"Unexpected weight name {tf_weight}. Please file an issue on the" - " Transformers repo to let us know about this error!" - ) - elif "text_model" in tf_weight: - return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),) - else: - return (tf_weight,) - - @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING) - def get_text_features( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - token_type_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - Returns: - text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying - the projection layer to the pooled output of [`TFCLIPTextModel`]. - - Examples: - - ```python - >>> from transformers import TFVisionTextDualEncoderModel, AutoTokenizer - - >>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True) - >>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian") - - >>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="np") - >>> text_features = model.get_text_features(**inputs) - ```""" - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - token_type_ids=token_type_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = text_outputs[1] - text_features = self.text_projection(pooled_output) - - return text_features - - @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING) - def get_image_features( - self, - pixel_values=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - Returns: - image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying - the projection layer to the pooled output of [`TFCLIPVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import TFVisionTextDualEncoderModel, AutoImageProcessor - - >>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True) - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = image_processor(images=image, return_tensors="np") - - >>> image_features = model.get_image_features(**inputs) - ```""" - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = vision_outputs[1] # pooled_output - image_features = self.visual_projection(pooled_output) - - return image_features - - @unpack_inputs - @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFCLIPOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: tf.Tensor | None = None, - pixel_values: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - return_loss: bool | None = None, - token_type_ids: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFCLIPOutput: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import ( - ... TFVisionTextDualEncoderModel, - ... VisionTextDualEncoderProcessor, - ... AutoImageProcessor, - ... AutoTokenizer, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") - >>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer) - >>> model = TFVisionTextDualEncoderModel.from_vision_text_pretrained( - ... "google/vit-base-patch16-224", "google-bert/bert-base-uncased" - ... ) - - >>> # contrastive training - >>> urls = [ - ... "http://images.cocodataset.org/val2017/000000039769.jpg", - ... "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg", - ... ] - >>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls] - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True - ... ) - >>> outputs = model( - ... input_ids=inputs.input_ids, - ... attention_mask=inputs.attention_mask, - ... pixel_values=inputs.pixel_values, - ... return_loss=True, - ... ) - >>> loss, logits_per_image = outputs.loss, outputs.logits_per_image # this is the image-text similarity score - - >>> # save and load from pretrained - >>> model.save_pretrained("vit-bert") - >>> model = TFVisionTextDualEncoderModel.from_pretrained("vit-bert") - - >>> # inference - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = tf.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities - ```""" - return_dict = return_dict if return_dict is not None else self.config.return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - image_embeds = vision_outputs[1] # pooler_output - image_embeds = self.visual_projection(image_embeds) - - text_embeds = text_outputs[1] # pooler_output - text_embeds = self.text_projection(text_embeds) - - # normalized features - image_embeds = image_embeds / tf.norm(image_embeds, axis=-1, keepdims=True) - text_embeds = text_embeds / tf.norm(text_embeds, axis=-1, keepdims=True) - - # cosine similarity as logits - logit_scale = tf.math.exp(self.logit_scale) - logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale - logits_per_image = tf.transpose(logits_per_text) - - loss = None - if return_loss: - loss = clip_loss(logits_per_text) - if loss.shape.rank == 0: - loss = tf.expand_dims(loss, 0) - - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - - return TFCLIPOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - @classmethod - def from_vision_text_pretrained( - cls, - vision_model_name_or_path: str | None = None, - text_model_name_or_path: str | None = None, - *model_args, - **kwargs, - ) -> TFPreTrainedModel: - """ - Params: - vision_model_name_or_path (`str`, *optional*, defaults to `None`): - Information necessary to initiate the vision model. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. - - text_model_name_or_path (`str`, *optional*): - Information necessary to initiate the text model. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the text configuration, use the prefix *text_* for each configuration parameter. - - To update the vision configuration, use the prefix *vision_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import TFVisionTextDualEncoderModel - - >>> # initialize a model from pretrained ViT and BERT models. Note that the projection layers will be randomly initialized. - >>> model = TFVisionTextDualEncoderModel.from_vision_text_pretrained( - ... "google/vit-base-patch16-224", "google-bert/bert-base-uncased" - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./vit-bert") - >>> # load fine-tuned model - >>> model = TFVisionTextDualEncoderModel.from_pretrained("./vit-bert") - ```""" - kwargs_vision = { - argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") - } - - kwargs_text = { - argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") - } - - # remove vision, text kwargs from kwargs - for key in kwargs_vision: - del kwargs["vision_" + key] - for key in kwargs_text: - del kwargs["text_" + key] - - # Load and initialize the vision and text model - vision_model = kwargs_vision.pop("model", None) - if vision_model is None: - if vision_model_name_or_path is None: - raise ValueError( - "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" - ) - kwargs_vision["name"] = "vision_model" - kwargs_vision["load_weight_prefix"] = cls.load_weight_prefix - - vision_config_dict, unused_args = PretrainedConfig.get_config_dict(vision_model_name_or_path, **kwargs) - if vision_config_dict.get("model_type", None) == "clip_vision_model": - vision_config = CLIPVisionConfig.from_dict(vision_config_dict) - else: - vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) - - if vision_config.model_type == "clip_vision_model": - kwargs_vision["config"] = vision_config - vision_class = TFCLIPVisionModel - elif vision_config.model_type == "clip": - kwargs_vision["config"] = vision_config.vision_config - vision_class = TFCLIPVisionModel - else: - kwargs_vision["config"] = vision_config - vision_class = TFAutoModel - vision_model = vision_class.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) - - text_model = kwargs_text.pop("model", None) - if text_model is None: - if text_model_name_or_path is None: - raise ValueError( - "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" - ) - kwargs_text["name"] = "text_model" - kwargs_text["load_weight_prefix"] = cls.load_weight_prefix - - if "config" not in kwargs_text: - text_config = AutoConfig.from_pretrained(text_model_name_or_path) - kwargs_text["config"] = text_config - - text_model = TFAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) - - # instantiate config with corresponding kwargs - config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs) - - # init model - model = cls(config=config, vision_model=vision_model, text_model=text_model) - - # the projection layers are always newly initialized when loading the model - # using pre-trained vision and text model. - logger.warning( - "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight'," - " 'logit_scale']` are newly initialized. You should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) - - if vision_model.name != "vision_model": - raise ValueError("vision model must be created with the name `vision_model`.") - if text_model.name != "text_model": - raise ValueError("text model must be created with the name `text_model`.") - - model.build_in_name_scope() # Ensure model is fully built - - return model - - @property - def dummy_inputs(self): - """ - Dummy inputs to build the network. - - Returns: - `dict[str, tf.Tensor]`: The dummy inputs. - """ - input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32) - batch_size, seq_len = input_ids.shape - - VISION_DUMMY_INPUTS = tf.random.uniform( - shape=( - batch_size, - self.config.vision_config.num_channels, - self.config.vision_config.image_size, - self.config.vision_config.image_size, - ), - dtype=tf.float32, - ) - pixel_values = tf.constant(VISION_DUMMY_INPUTS) - dummy = {"pixel_values": pixel_values, "input_ids": input_ids} - return dummy - - -__all__ = ["TFVisionTextDualEncoderModel"] diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py deleted file mode 100644 index d62ef1b6b928..000000000000 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ /dev/null @@ -1,677 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward -from .configuration_vit import ViTConfig - - -VIT_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`ViTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -VIT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxViTPatchEmbeddings(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - image_size = self.config.image_size - patch_size = self.config.patch_size - num_patches = (image_size // patch_size) * (image_size // patch_size) - self.num_patches = num_patches - self.num_channels = self.config.num_channels - self.projection = nn.Conv( - self.config.hidden_size, - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - padding="VALID", - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - ) - - def __call__(self, pixel_values): - num_channels = pixel_values.shape[-1] - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - embeddings = self.projection(pixel_values) - batch_size, _, _, channels = embeddings.shape - return jnp.reshape(embeddings, (batch_size, -1, channels)) - - -class FlaxViTEmbeddings(nn.Module): - """Construct the CLS token, position and patch embeddings.""" - - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.cls_token = self.param( - "cls_token", - jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), - (1, 1, self.config.hidden_size), - ) - self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = self.param( - "position_embeddings", - jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), - (1, num_patches + 1, self.config.hidden_size), - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, pixel_values, deterministic=True): - batch_size = pixel_values.shape[0] - - embeddings = self.patch_embeddings(pixel_values) - - cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) - embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) - embeddings = embeddings + self.position_embeddings - embeddings = self.dropout(embeddings, deterministic=deterministic) - return embeddings - - -class FlaxViTSelfAttention(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" - " {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" - ), - use_bias=self.config.qkv_bias, - ) - - def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): - head_dim = self.config.hidden_size // self.config.num_attention_heads - - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -class FlaxViTSelfOutput(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxViTAttention(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype) - self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): - attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -class FlaxViTIntermediate(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class FlaxViTOutput(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = hidden_states + attention_output - return hidden_states - - -class FlaxViTLayer(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxViTAttention(self.config, dtype=self.dtype) - self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype) - self.output = FlaxViTOutput(self.config, dtype=self.dtype) - self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): - attention_outputs = self.attention( - self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention - deterministic=deterministic, - output_attentions=output_attentions, - ) - - attention_output = attention_outputs[0] - - # first residual connection - attention_output = attention_output + hidden_states - - # in ViT, layernorm is also applied after self-attention - layer_output = self.layernorm_after(attention_output) - - hidden_states = self.intermediate(layer_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - return outputs - - -class FlaxViTLayerCollection(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states,) - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxViTEncoder(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class FlaxViTPooler(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.pooler_output_size, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.pooler_act] - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return self.activation(cls_hidden_state) - - -class FlaxViTPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ViTConfig - base_model_prefix = "vit" - main_input_name = "pixel_values" - module_class: nn.Module = None - - def __init__( - self, - config: ViTConfig, - input_shape=None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - if input_shape is None: - input_shape = (1, config.image_size, config.image_size, config.num_channels) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - pixel_values = jnp.zeros(input_shape, dtype=self.dtype) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - pixel_values, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) - - -class FlaxViTModule(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - - def setup(self): - self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype) - self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None - - def __call__( - self, - pixel_values, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - hidden_states = self.embeddings(pixel_values, deterministic=deterministic) - - outputs = self.encoder( - hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - hidden_states = self.layernorm(hidden_states) - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPooling( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", - VIT_START_DOCSTRING, -) -class FlaxViTModel(FlaxViTPreTrainedModel): - module_class = FlaxViTModule - - -FLAX_VISION_MODEL_DOCSTRING = """ - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, FlaxViTModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") - >>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig) - - -class FlaxViTForImageClassificationModule(nn.Module): - config: ViTConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - self.classifier = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.variance_scaling( - self.config.initializer_range**2, "fan_in", "truncated_normal" - ), - ) - - def __call__( - self, - pixel_values=None, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.vit( - pixel_values, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.classifier(hidden_states[:, 0, :]) - - if not return_dict: - output = (logits,) + outputs[2:] - return output - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - """, - VIT_START_DOCSTRING, -) -class FlaxViTForImageClassification(FlaxViTPreTrainedModel): - module_class = FlaxViTForImageClassificationModule - - -FLAX_VISION_CLASSIF_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoImageProcessor, FlaxViTForImageClassification - >>> from PIL import Image - >>> import jax - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") - >>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="np") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) - >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) - ``` -""" - -overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) -append_replace_return_docstrings( - FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig -) - - -__all__ = ["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"] diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py deleted file mode 100644 index 80d785e32114..000000000000 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ /dev/null @@ -1,906 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ViT model.""" - -from __future__ import annotations - -import collections.abc -import math - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_vit import ViTConfig - - -logger = logging.get_logger(__name__) - -# General docstring -_CONFIG_FOR_DOC = "ViTConfig" - -# Base docstring -_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k" -_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] - -# Image classification docstring -_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224" -_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" - - -class TFViTEmbeddings(keras.layers.Layer): - """ - Construct the CLS token, position and patch embeddings. - - """ - - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def build(self, input_shape=None): - num_patches = self.patch_embeddings.num_patches - self.cls_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=get_initializer(self.config.initializer_range), - trainable=True, - name="cls_token", - ) - self.position_embeddings = self.add_weight( - shape=(1, num_patches + 1, self.config.hidden_size), - initializer=get_initializer(self.config.initializer_range), - trainable=True, - name="position_embeddings", - ) - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - - def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - batch_size, seq_len, dim = shape_list(embeddings) - num_patches = seq_len - 1 - - _, num_positions, _ = shape_list(self.position_embeddings) - num_positions -= 1 - - if num_patches == num_positions and height == width: - return self.position_embeddings - class_pos_embed = self.position_embeddings[:, :1] - patch_pos_embed = self.position_embeddings[:, 1:] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - patch_pos_embed = tf.image.resize( - images=tf.reshape( - patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - ), - size=(h0, w0), - method="bicubic", - ) - - shape = shape_list(patch_pos_embed) - assert h0 == shape[-3] and w0 == shape[-2] - patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) - return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) - - def call( - self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False - ) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - embeddings = self.patch_embeddings( - pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training - ) - - # add the [CLS] token to the embedded patch tokens - cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0) - embeddings = tf.concat((cls_tokens, embeddings), axis=1) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings, training=training) - - return embeddings - - -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class TFViTPatchEmbeddings(keras.layers.Layer): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = num_patches - self.num_channels = num_channels - self.config = config - - self.projection = keras.layers.Conv2D( - filters=hidden_size, - kernel_size=patch_size, - strides=patch_size, - padding="valid", - data_format="channels_last", - use_bias=True, - kernel_initializer=get_initializer(self.config.initializer_range), - bias_initializer="zeros", - name="projection", - ) - - def call( - self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False - ) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - if not interpolate_pos_encoding: - if tf.executing_eagerly(): - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - projection = self.projection(pixel_values) - - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) - embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) - - return embeddings - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -class TFViTSelfAttention(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -class TFViTSelfOutput(keras.layers.Layer): - """ - The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the - layernorm applied before each block. - """ - - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFViTAttention(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFViTSelfAttention(config, name="attention") - self.dense_output = TFViTSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -class TFViTIntermediate(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -class TFViTOutput(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = hidden_states + input_tensor - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -class TFViTLayer(keras.layers.Layer): - """This corresponds to the Block class in the timm implementation.""" - - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFViTAttention(config, name="attention") - self.intermediate = TFViTIntermediate(config, name="intermediate") - self.vit_output = TFViTOutput(config, name="output") - - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - # in ViT, layernorm is applied before self-attention - input_tensor=self.layernorm_before(inputs=hidden_states), - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - - # first residual connection - hidden_states = attention_output + hidden_states - - # in ViT, layernorm is also applied after self-attention - layer_output = self.layernorm_after(inputs=hidden_states) - - intermediate_output = self.intermediate(hidden_states=layer_output) - - # second residual connection is done here - layer_output = self.vit_output( - hidden_states=intermediate_output, input_tensor=hidden_states, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "vit_output", None) is not None: - with tf.name_scope(self.vit_output.name): - self.vit_output.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.config.hidden_size]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.config.hidden_size]) - - -class TFViTEncoder(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - head_mask=head_mask[i], - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFViTMainLayer(keras.layers.Layer): - config_class = ViTConfig - - def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFViTEmbeddings(config, name="embeddings") - self.encoder = TFViTEncoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - embedding_output = self.embeddings( - pixel_values=pixel_values, - interpolate_pos_encoding=interpolate_pos_encoding, - training=training, - ) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(inputs=sequence_output) - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return TFBaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.hidden_size]) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - - -class TFViTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ViTConfig - base_model_prefix = "vit" - main_input_name = "pixel_values" - - -VIT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`ViTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -VIT_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - interpolate_pos_encoding (`bool`, *optional*): - Whether to interpolate the pre-trained position encodings. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", - VIT_START_DOCSTRING, -) -class TFViTModel(TFViTPreTrainedModel): - def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit") - - @unpack_inputs - @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPooling, - config_class=_CONFIG_FOR_DOC, - modality="vision", - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: - outputs = self.vit( - pixel_values=pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vit", None) is not None: - with tf.name_scope(self.vit.name): - self.vit.build(None) - - -class TFViTPooler(keras.layers.Layer): - def __init__(self, config: ViTConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.pooler_output_size, - kernel_initializer=get_initializer(config.initializer_range), - activation=config.pooler_act, - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of - the [CLS] token) e.g. for ImageNet. - - - - Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by - setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained - position embeddings to the higher resolution. - - - """, - VIT_START_DOCSTRING, -) -class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config: ViTConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.num_labels = config.num_labels - self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit") - - # Classifier head - self.classifier = keras.layers.Dense( - units=config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="classifier", - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, - ) - def call( - self, - pixel_values: TFModelInputType | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - - outputs = self.vit( - pixel_values=pixel_values, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(inputs=sequence_output[:, 0, :]) - loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vit", None) is not None: - with tf.name_scope(self.vit.name): - self.vit.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -__all__ = ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel"] diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py deleted file mode 100644 index d0184e92b37b..000000000000 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ /dev/null @@ -1,1374 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 ViT MAE (masked autoencoder) model.""" - -from __future__ import annotations - -import collections.abc -import math -from copy import deepcopy -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...file_utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from ...modeling_tf_outputs import TFBaseModelOutput -from ...modeling_tf_utils import ( - TFModelInputType, - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import logging -from .configuration_vit_mae import ViTMAEConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "ViTMAEConfig" -_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base" - - -@dataclass -class TFViTMAEModelOutput(ModelOutput): - """ - Class for TFViTMAEModel's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor indicating which patches are masked (1) and which are not (0). - ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor containing the original index of the (shuffled) masked patches. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - last_hidden_state: tf.Tensor | None = None - mask: tf.Tensor | None = None - ids_restore: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFViTMAEDecoderOutput(ModelOutput): - """ - Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions. - - Args: - logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): - Pixel reconstruction logits. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -@dataclass -class TFViTMAEForPreTrainingOutput(ModelOutput): - """ - Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions. - - Args: - loss (`tf.Tensor` of shape `(1,)`): - Pixel reconstruction loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): - Pixel reconstruction logits. - mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor indicating which patches are masked (1) and which are not (0). - ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor containing the original index of the (shuffled) masked patches. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus - the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mask: tf.Tensor | None = None - ids_restore: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): - """ - Create 2D sin/cos positional embeddings. - - Args: - embed_dim (`int`): - Embedding dimension. - grid_size (`int`): - The grid height and width. - add_cls_token (`bool`, *optional*, defaults to `False`): - Whether or not to add a classification (CLS) token. - - Returns: - (`tf.Tensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the position - embeddings (with or without classification token) - """ - grid_h = tf.range(grid_size, dtype=tf.float32) - grid_w = tf.range(grid_size, dtype=tf.float32) - grid = tf.meshgrid(grid_w, grid_h) # here w goes first - grid = tf.stack(grid, axis=0) - - grid = tf.reshape(grid, [2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if add_cls_token: - pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be even") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = tf.concat([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be even") - - omega = tf.range(embed_dim // 2, dtype="float32") - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = tf.reshape(pos, [-1]) # (M,) - out = tf.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - # half of the positions get sinusoidal pattern and the rest gets - # cosine pattern and then they are concatenated - emb_sin = tf.sin(out) # (M, D/2) - emb_cos = tf.cos(out) # (M, D/2) - - emb = tf.concat([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class TFViTMAEEmbeddings(keras.layers.Layer): - """ - Construct the CLS token, position and patch embeddings. - - """ - - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings") - self.num_patches = self.patch_embeddings.num_patches - - self.config = config - - def build(self, input_shape=None): - self.cls_token = self.add_weight( - shape=(1, 1, self.config.hidden_size), - initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), - trainable=True, - name="cls_token", - ) - self.position_embeddings = self.add_weight( - shape=(1, self.num_patches + 1, self.config.hidden_size), - initializer="zeros", - trainable=False, # fixed sin-cos embedding - name="position_embeddings", - ) - pos_embed = get_2d_sincos_pos_embed( - self.position_embeddings.shape[-1], - int(self.patch_embeddings.num_patches**0.5), - add_cls_token=True, - )[None, ...] - self.position_embeddings.assign(pos_embed) - - if self.built: - return - self.built = True - if getattr(self, "patch_embeddings", None) is not None: - with tf.name_scope(self.patch_embeddings.name): - self.patch_embeddings.build(None) - - def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - batch_size, seq_len, dim = shape_list(embeddings) - num_patches = seq_len - 1 - - _, num_positions, _ = shape_list(self.position_embeddings) - num_positions -= 1 - - if num_patches == num_positions and height == width: - return self.position_embeddings - class_pos_embed = self.position_embeddings[:, :1] - patch_pos_embed = self.position_embeddings[:, 1:] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - patch_pos_embed = tf.image.resize( - images=tf.reshape( - patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - ), - size=(h0, w0), - method="bicubic", - ) - - patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) - return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) - - def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None): - """ - Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random - noise. - - Args: - sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`) - noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*) which is - mainly used for testing purposes to control randomness and maintain the reproducibility - """ - batch_size, seq_length, dim = shape_list(sequence) - len_keep = int(seq_length * (1 - self.config.mask_ratio)) - - if noise is None: - noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0) # noise in [0, 1) - - # sort noise for each sample - ids_shuffle = tf.argsort(noise, axis=1) # ascend: small is keep, large is remove - ids_restore = tf.argsort(ids_shuffle, axis=1) - - # keep the first subset - ids_keep = ids_shuffle[:, :len_keep] - sequence_unmasked = tf.gather( - sequence, - axis=1, - batch_dims=1, - indices=ids_keep, - ) - - # generate the binary mask: 0 is keep, 1 is remove - # this hack is needed because TF's EagerTensors don't support - # assignment - mask_keep = tf.zeros((batch_size, len_keep)) - mask_remove = tf.ones((batch_size, seq_length - len_keep)) - mask = tf.concat([mask_keep, mask_remove], axis=-1) - - # unshuffle to get the binary mask - mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore) - - return sequence_unmasked, mask, ids_restore - - def call( - self, pixel_values: tf.Tensor, noise: tf.Tensor | None = None, interpolate_pos_encoding: bool = False - ) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - if interpolate_pos_encoding: - position_embeddings = self.interpolate_pos_encoding(embeddings, height, width) - else: - position_embeddings = self.position_embeddings - # add position embeddings w/o cls token - embeddings = embeddings + position_embeddings[:, 1:, :] - - # masking: length -> length * config.mask_ratio - embeddings, mask, ids_restore = self.random_masking(embeddings, noise) - - # append cls token - cls_token = self.cls_token + position_embeddings[:, :1, :] - cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1)) - embeddings = tf.concat([cls_tokens, embeddings], axis=1) - - return embeddings, mask, ids_restore - - -class TFViTMAEPatchEmbeddings(keras.layers.Layer): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = num_patches - self.num_channels = num_channels - self.config = config - - self.projection = keras.layers.Conv2D( - filters=hidden_size, - kernel_size=patch_size, - strides=patch_size, - padding="valid", - data_format="channels_last", - kernel_initializer="glorot_uniform", # following torch.nn.Linear - bias_initializer="zeros", - name="projection", - ) - - def call( - self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False - ) -> tf.Tensor: - batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly(): - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the" - " configuration." - ) - if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - - # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. - # So change the input format from `NCHW` to `NHWC`. - # shape = (batch_size, in_height, in_width, in_channels=num_channels) - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - projection = self.projection(pixel_values) - - # Change the 2D spatial dimensions to a single temporal dimension. - # shape = (batch_size, num_patches, out_channels=embed_dim) - num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) - x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) - - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, None, self.num_channels]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE -class TFViTMAESelfAttention(keras.layers.Layer): - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - mixed_key_layer = self.key(inputs=hidden_states) - mixed_value_layer = self.value(inputs=hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) - value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE -class TFViTMAESelfOutput(keras.layers.Layer): - """ - The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the - layernorm applied before each block. - """ - - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE -class TFViTMAEAttention(keras.layers.Layer): - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFViTMAESelfAttention(config, name="attention") - self.dense_output = TFViTMAESelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE -class TFViTMAEIntermediate(keras.layers.Layer): - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE -class TFViTMAEOutput(keras.layers.Layer): - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = hidden_states + input_tensor - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE -class TFViTMAELayer(keras.layers.Layer): - """This corresponds to the Block class in the timm implementation.""" - - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFViTMAEAttention(config, name="attention") - self.intermediate = TFViTMAEIntermediate(config, name="intermediate") - self.vit_output = TFViTMAEOutput(config, name="output") - - self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - attention_outputs = self.attention( - # in ViTMAE, layernorm is applied before self-attention - input_tensor=self.layernorm_before(inputs=hidden_states), - head_mask=head_mask, - output_attentions=output_attentions, - training=training, - ) - attention_output = attention_outputs[0] - - # first residual connection - hidden_states = attention_output + hidden_states - - # in ViTMAE, layernorm is also applied after self-attention - layer_output = self.layernorm_after(inputs=hidden_states) - - intermediate_output = self.intermediate(hidden_states=layer_output) - - # second residual connection is done here - layer_output = self.vit_output( - hidden_states=intermediate_output, input_tensor=hidden_states, training=training - ) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "vit_output", None) is not None: - with tf.name_scope(self.vit_output.name): - self.vit_output.build(None) - if getattr(self, "layernorm_before", None) is not None: - with tf.name_scope(self.layernorm_before.name): - self.layernorm_before.build([None, None, self.config.hidden_size]) - if getattr(self, "layernorm_after", None) is not None: - with tf.name_scope(self.layernorm_after.name): - self.layernorm_after.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE -class TFViTMAEEncoder(keras.layers.Layer): - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - head_mask: tf.Tensor, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states=hidden_states, - head_mask=head_mask[i], - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFViTMAEMainLayer(keras.layers.Layer): - config_class = ViTMAEConfig - - def __init__(self, config: ViTMAEConfig, **kwargs): - super().__init__(**kwargs) - - self.config = config - - self.embeddings = TFViTMAEEmbeddings(config, name="embeddings") - self.encoder = TFViTMAEEncoder(config, name="encoder") - self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") - - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - pixel_values: TFModelInputType | None = None, - noise: tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - interpolate_pos_encoding: bool = False, - ) -> TFViTMAEModelOutput | tuple[tf.Tensor]: - embedding_output, mask, ids_restore = self.embeddings( - pixel_values=pixel_values, - training=training, - noise=noise, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - embedding_output, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(inputs=sequence_output) - - if not return_dict: - return (sequence_output, mask, ids_restore) + encoder_outputs[1:] - - return TFViTMAEModelOutput( - last_hidden_state=sequence_output, - mask=mask, - ids_restore=ids_restore, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "layernorm", None) is not None: - with tf.name_scope(self.layernorm.name): - self.layernorm.build([None, None, self.config.hidden_size]) - - -class TFViTMAEPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ViTMAEConfig - base_model_prefix = "vit" - main_input_name = "pixel_values" - - -VIT_MAE_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -VIT_MAE_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] - for details. - - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used - in eager mode, in graph mode the value will always be set to True. - - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). - - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the position encodings at the encoder and decoder. -""" - - -@add_start_docstrings( - "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.", - VIT_MAE_START_DOCSTRING, -) -class TFViTMAEModel(TFViTMAEPreTrainedModel): - def __init__(self, config: ViTMAEConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.vit = TFViTMAEMainLayer(config, name="vit") - - def get_input_embeddings(self): - return self.vit.get_input_embeddings() - - @unpack_inputs - @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: TFModelInputType | None = None, - noise: tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - interpolate_pos_encoding: bool = False, - ) -> TFViTMAEModelOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFViTMAEModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") - >>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base") - - >>> inputs = image_processor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - outputs = self.vit( - pixel_values=pixel_values, - noise=noise, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vit", None) is not None: - with tf.name_scope(self.vit.name): - self.vit.build(None) - - -class TFViTMAEDecoder(keras.layers.Layer): - def __init__(self, config, num_patches, **kwargs): - super().__init__(**kwargs) - self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed") - - decoder_config = deepcopy(config) - decoder_config.hidden_size = config.decoder_hidden_size - decoder_config.num_hidden_layers = config.decoder_num_hidden_layers - decoder_config.num_attention_heads = config.decoder_num_attention_heads - decoder_config.intermediate_size = config.decoder_intermediate_size - self.decoder_layers = [ - TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers) - ] - - self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm") - self.decoder_pred = keras.layers.Dense( - config.patch_size**2 * config.num_channels, - kernel_initializer=get_initializer(config.initializer_range), - name="decoder_pred", - ) # encoder to decoder - self.config = config - self.num_patches = num_patches - - def build(self, input_shape=None): - self.mask_token = self.add_weight( - shape=(1, 1, self.config.decoder_hidden_size), - initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), - trainable=True, - name="mask_token", - ) - self.decoder_pos_embed = self.add_weight( - shape=(1, self.num_patches + 1, self.config.decoder_hidden_size), - initializer="zeros", - trainable=False, - name="decoder_pos_embed", - ) - decoder_pos_embed = get_2d_sincos_pos_embed( - self.decoder_pos_embed.shape[-1], - int(self.num_patches**0.5), - add_cls_token=True, - )[None, ...] - self.decoder_pos_embed.assign(decoder_pos_embed) - - if self.built: - return - self.built = True - if getattr(self, "decoder_embed", None) is not None: - with tf.name_scope(self.decoder_embed.name): - self.decoder_embed.build([None, None, self.config.hidden_size]) - if getattr(self, "decoder_norm", None) is not None: - with tf.name_scope(self.decoder_norm.name): - self.decoder_norm.build([None, None, self.config.decoder_hidden_size]) - if getattr(self, "decoder_pred", None) is not None: - with tf.name_scope(self.decoder_pred.name): - self.decoder_pred.build([None, None, self.config.decoder_hidden_size]) - if getattr(self, "decoder_layers", None) is not None: - for layer in self.decoder_layers: - with tf.name_scope(layer.name): - layer.build(None) - - def interpolate_pos_encoding(self, embeddings) -> tf.Tensor: - """ - This method is a modified version of the interpolation function for ViT-mae model at the decoder, that - allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - # [batch_size, num_patches + 1, hidden_size] - _, num_positions, dim = shape_list(self.decoder_pos_embed) - - # -1 removes the class dimension since we later append it without interpolation - seq_len = shape_list(embeddings)[1] - 1 - num_positions = num_positions - 1 - - # Separation of class token and patch tokens - class_pos_embed = self.decoder_pos_embed[:, :1, :] - patch_pos_embed = self.decoder_pos_embed[:, 1:, :] - - # interpolate the position embeddings - patch_pos_embed = tf.image.resize( - images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)), - size=(1, seq_len), - method="bicubic", - ) - - # [1, seq_len, hidden_size] - patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) - # Adding the class token back - return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) - - def call( - self, - hidden_states, - ids_restore, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - interpolate_pos_encoding=False, - ): - # embed tokens - x = self.decoder_embed(hidden_states) - # append mask tokens to sequence - mask_tokens = tf.tile( - self.mask_token, - (shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1), - ) - x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token - x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle - x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token - if interpolate_pos_encoding: - decoder_pos_embed = self.interpolate_pos_encoding(x) - else: - decoder_pos_embed = self.decoder_pos_embed - # add pos embed - hidden_states = x + decoder_pos_embed - # apply Transformer layers (blocks) - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - for i, layer_module in enumerate(self.decoder_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - head_mask=None, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = self.decoder_norm(hidden_states) - - # predictor projection - logits = self.decoder_pred(hidden_states) - - # remove cls token - logits = logits[:, 1:, :] - - if not return_dict: - return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None) - return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions) - - -@add_start_docstrings( - "The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.", - VIT_MAE_START_DOCSTRING, -) -class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.config = config - - self.vit = TFViTMAEMainLayer(config, name="vit") - self.decoder = TFViTMAEDecoder( - config, - num_patches=self.vit.embeddings.num_patches, - name="decoder", - ) - - def get_input_embeddings(self): - return self.vit.get_input_embeddings() - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError - - def patchify(self, pixel_values, interpolate_pos_encoding: bool = False): - """ - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`): - Pixel values. - interpolate_pos_encoding (`bool`, default `False`): - interpolation flag passed during the forward pass. - - Returns: - `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: - Patchified pixel values. - """ - patch_size, num_channels = self.config.patch_size, self.config.num_channels - # make sure channels are last - if shape_list(pixel_values)[1] == num_channels: - pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - - # sanity checks - if not interpolate_pos_encoding: - tf.debugging.assert_equal( - shape_list(pixel_values)[1], - shape_list(pixel_values)[2], - message="Make sure the pixel values have a squared size", - ) - tf.debugging.assert_equal( - shape_list(pixel_values)[1] % patch_size, - 0, - message="Make sure the pixel values have a size that is divisible by the patch size", - ) - tf.debugging.assert_equal( - shape_list(pixel_values)[3], - num_channels, - message=( - "Make sure the number of channels of the pixel values is equal to the one set in the configuration" - ), - ) - - # patchify - batch_size = shape_list(pixel_values)[0] - num_patches_h = shape_list(pixel_values)[1] // patch_size - num_patches_w = shape_list(pixel_values)[2] // patch_size - patchified_pixel_values = tf.reshape( - pixel_values, - (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels), - ) - patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values) - patchified_pixel_values = tf.reshape( - patchified_pixel_values, - (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels), - ) - return patchified_pixel_values - - def unpatchify(self, patchified_pixel_values, original_image_size: tuple[int, int] | None = None): - """ - Args: - patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: - Patchified pixel values. - original_image_size (`tuple[int, int]`, *optional*): - Original image size. - - Returns: - `tf.Tensor` of shape `(batch_size, height, width, num_channels)`: - Pixel values. - """ - patch_size, num_channels = self.config.patch_size, self.config.num_channels - original_image_size = ( - original_image_size - if original_image_size is not None - else (self.config.image_size, self.config.image_size) - ) - original_height, original_width = original_image_size - num_patches_h = original_height // patch_size - num_patches_w = original_width // patch_size - # sanity check - tf.debugging.assert_equal( - num_patches_h * num_patches_w, - shape_list(patchified_pixel_values)[1], - message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}", - ) - - # unpatchify - batch_size = shape_list(patchified_pixel_values)[0] - patchified_pixel_values = tf.reshape( - patchified_pixel_values, - (batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels), - ) - patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values) - pixel_values = tf.reshape( - patchified_pixel_values, - (batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels), - ) - return pixel_values - - def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False): - """ - Args: - pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`): - Pixel values. - pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: - Predicted pixel values. - mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Tensor indicating which patches are masked (1) and which are not (0). - interpolate_pos_encoding (`bool`, *optional*, default `False`): - interpolation flag passed during the forward pass. - - Returns: - `tf.Tensor`: Pixel reconstruction loss. - """ - target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - if self.config.norm_pix_loss: - mean = tf.reduce_mean(target, axis=-1, keepdims=True) - var = tf.math.reduce_variance(target, axis=-1, keepdims=True) - target = (target - mean) / (var + 1.0e-6) ** 0.5 - - loss = (pred - target) ** 2 - loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch - - loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches - loss = tf.reshape(loss, (1,)) - return loss - - @unpack_inputs - @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - pixel_values: TFModelInputType | None = None, - noise: tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - interpolate_pos_encoding: bool = False, - ) -> TFViTMAEForPreTrainingOutput | tuple[tf.Tensor]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") - >>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base") - - >>> inputs = image_processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> loss = outputs.loss - >>> mask = outputs.mask - >>> ids_restore = outputs.ids_restore - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.vit( - pixel_values=pixel_values, - noise=noise, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - latent = outputs.last_hidden_state - ids_restore = outputs.ids_restore - mask = outputs.mask - - # [batch_size, num_patches, patch_size**2*3] - decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) - logits = decoder_outputs.logits - - loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) - - if not return_dict: - output = (logits, mask, ids_restore) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFViTMAEForPreTrainingOutput( - loss=loss, - logits=logits, - mask=mask, - ids_restore=ids_restore, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "vit", None) is not None: - with tf.name_scope(self.vit.name): - self.vit.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -__all__ = ["TFViTMAEForPreTraining", "TFViTMAEModel", "TFViTMAEPreTrainedModel"] diff --git a/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py b/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py deleted file mode 100644 index bf6aa8e4a36b..000000000000 --- a/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py +++ /dev/null @@ -1,231 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL: -https://github.com/google-research/scenic/tree/main/scenic/projects/vivit -""" - -import argparse -import json -import os.path -from collections import OrderedDict - -import numpy as np -import requests -import torch -from flax.training.checkpoints import restore_checkpoint -from huggingface_hub import hf_hub_download - -from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor -from transformers.image_utils import PILImageResampling - - -def download_checkpoint(path): - url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint" - - with open(path, "wb") as f: - with requests.get(url, stream=True) as req: - for chunk in req.iter_content(chunk_size=2048): - f.write(chunk) - - -def get_vivit_config() -> VivitConfig: - config = VivitConfig() - - config.num_labels = 400 - repo_id = "huggingface/label-files" - filename = "kinetics400-id2label.json" - - id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) - id2label = {int(k): v for k, v in id2label.items()} - config.id2label = id2label - config.label2id = {v: k for k, v in id2label.items()} - return config - - -# We will verify our results on a video of eating spaghetti -# Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117, -# 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174] -def prepare_video(): - file = hf_hub_download( - repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset" - ) - video = np.load(file) - return list(video) - - -def transform_attention(current: np.ndarray): - if np.ndim(current) == 2: - return transform_attention_bias(current) - - elif np.ndim(current) == 3: - return transform_attention_kernel(current) - - else: - raise Exception(f"Invalid number of dimensions: {np.ndim(current)}") - - -def transform_attention_bias(current: np.ndarray): - return current.flatten() - - -def transform_attention_kernel(current: np.ndarray): - return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T - - -def transform_attention_output_weight(current: np.ndarray): - return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T - - -def transform_state_encoder_block(state_dict, i): - state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"] - - prefix = f"encoder.layer.{i}." - new_state = { - prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"], - prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]), - prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"], - prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]), - prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"], - prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"], - prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"], - prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"], - prefix + "attention.attention.query.bias": transform_attention( - state["MultiHeadDotProductAttention_0"]["query"]["bias"] - ), - prefix + "attention.attention.query.weight": transform_attention( - state["MultiHeadDotProductAttention_0"]["query"]["kernel"] - ), - prefix + "attention.attention.key.bias": transform_attention( - state["MultiHeadDotProductAttention_0"]["key"]["bias"] - ), - prefix + "attention.attention.key.weight": transform_attention( - state["MultiHeadDotProductAttention_0"]["key"]["kernel"] - ), - prefix + "attention.attention.value.bias": transform_attention( - state["MultiHeadDotProductAttention_0"]["value"]["bias"] - ), - prefix + "attention.attention.value.weight": transform_attention( - state["MultiHeadDotProductAttention_0"]["value"]["kernel"] - ), - prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"], - prefix + "attention.output.dense.weight": transform_attention_output_weight( - state["MultiHeadDotProductAttention_0"]["out"]["kernel"] - ), - } - - return new_state - - -def get_n_layers(state_dict): - return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"]]) - - -def transform_state(state_dict, classification_head=False): - transformer_layers = get_n_layers(state_dict) - - new_state = OrderedDict() - - new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"] - new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"] - - new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose( - state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2) - ) - new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"] - - new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"] - new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][ - "pos_embedding" - ] - - for i in range(transformer_layers): - new_state.update(transform_state_encoder_block(state_dict, i)) - - if classification_head: - new_state = {"vivit." + k: v for k, v in new_state.items()} - new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"]) - new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"]) - - return {k: torch.tensor(v) for k, v in new_state.items()} - - -# checks that image processor settings are the same as in the original implementation -# original: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py -# dataset specific config: -# https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py -def get_processor() -> VivitImageProcessor: - extractor = VivitImageProcessor() - - assert extractor.do_resize is True - assert extractor.size == {"shortest_edge": 256} - assert extractor.do_center_crop is True - assert extractor.crop_size == {"width": 224, "height": 224} - assert extractor.resample == PILImageResampling.BILINEAR - - # here: https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py - # one can seen that add_image has default values for normalization_mean and normalization_std set to 0 and 1 - # which effectively means no normalization (and ViViT does not overwrite those when calling this func) - assert extractor.do_normalize is False - assert extractor.do_rescale is True - assert extractor.rescale_factor == 1 / 255 - - # zero-centering = True in original implementation - assert extractor.do_zero_centering is True - - return extractor - - -def convert(output_path: str): - flax_model_path = "checkpoint" - - if not os.path.exists(flax_model_path): - download_checkpoint(flax_model_path) - - state_dict = restore_checkpoint(flax_model_path, None) - new_state = transform_state(state_dict, classification_head=True) - - config = get_vivit_config() - - assert config.image_size == 224 - assert config.num_frames == 32 - - model = VivitForVideoClassification(config) - model.load_state_dict(new_state) - model.eval() - - extractor = get_processor() - - video = prepare_video() - inputs = extractor(video, return_tensors="pt") - - outputs = model(**inputs) - - expected_shape = torch.Size([1, 400]) - expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]) - - assert outputs.logits.shape == expected_shape - assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5] - - model.save_pretrained(output_path) - extractor.save_pretrained(output_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument("--output_model_name", "-o", type=str, help="Output path for the converted HuggingFace model") - - args = parser.parse_args() - convert(args.output_model_name) diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py deleted file mode 100644 index bc5a396dcad4..000000000000 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ /dev/null @@ -1,1423 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax Wav2Vec2 model.""" - -from functools import partial -from typing import Optional, Union - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_wav2vec2 import Wav2Vec2Config - - -logger = logging.get_logger(__name__) - - -@flax.struct.dataclass -class FlaxWav2Vec2BaseModelOutput(ModelOutput): - """ - Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions. - - Args: - last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`): - Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim` - being the dimension of the last convolutional layer. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: jnp.ndarray = None - extract_features: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -@flax.struct.dataclass -class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput): - """ - Output type of [`FlaxWav2Vec2ForPreTrainingOutput`], with potential hidden states and attentions. - - Args: - loss (*optional*, returned when model is in train mode, `jnp.ndarray` of shape `(1,)`): - Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official - paper](https://huggingface.co/papers/2006.11477). - projected_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): - Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked - projected quantized states. - projected_quantized_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): - Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive - target vectors for contrastive loss. - hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - projected_states: jnp.ndarray = None - projected_quantized_states: jnp.ndarray = None - codevector_perplexity: jnp.ndarray = None - hidden_states: Optional[tuple[jnp.ndarray]] = None - attentions: Optional[tuple[jnp.ndarray]] = None - - -def _compute_mask_indices( - shape: tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[np.ndarray] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: the shape for which to compute masks. - should be of size 2 where first element is batch size and 2nd is timesteps - mask_prob: - probability for each token to be chosen as start of the span to be masked. this will be multiplied by - number of timesteps divided by length of mask span to mask approximately this percentage of all elements. - however due to overlaps, the actual number will be smaller (unless no_overlap is True) - mask_length: size of the mask - min_masks: minimum number of masked spans - - """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" - f" `sequence_length`: {sequence_length}`" - ) - - # compute number of masked spans in batch - num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item()) - num_masked_spans = max(num_masked_spans, min_masks) - - # make sure num masked indices <= sequence_length - if num_masked_spans * mask_length > sequence_length: - num_masked_spans = sequence_length // mask_length - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - - # get random indices to mask - spec_aug_mask_idxs = np.array( - [ - np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False) - for _ in range(batch_size) - ] - ) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length)) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length) - - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape( - batch_size, num_masked_spans * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - if attention_mask is not None: - # make sure padded input ids cannot be masked - spec_aug_mask = np.where(attention_mask, spec_aug_mask, False) - - return spec_aug_mask - - -def _sample_negative_indices(features_shape: tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None): - """ - Sample `num_negatives` vectors from feature vectors. - """ - batch_size, sequence_length, hidden_size = features_shape - if sequence_length <= 1: - raise ValueError( - "`features should have `sequence_length` > 1, but are of shape " - f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." - ) - - # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = [] - for batch_idx in range(batch_size): - high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 - sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,)) - sampled_negative_indices.append(sampled_indices_slice) - - sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32) - - # generate indices of the positive vectors themselves, repeat them `num_negatives` times - feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten() - - # avoid sampling the same positive vector, but keep the distribution uniform - sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 - - # correct for batch size - for batch_idx in range(1, batch_size): - sampled_negative_indices[batch_idx] += batch_idx * sequence_length - - return sampled_negative_indices - - -WAV2VEC2_START_DOCSTRING = r""" - Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech - Representations](https://huggingface.co/papers/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael - Auli. - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -WAV2VEC2_INPUTS_DOCSTRING = r""" - Args: - input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file - into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library - (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). - To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion - into a tensor of type `jnp.ndarray`. See [`Wav2Vec2Processor.__call__`] for details. - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, - 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed - if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor - has `config.return_attention_mask == False`, such as - [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be - passed to avoid degraded performance when doing batched inference. For such models `input_values` should - simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly - different results depending on whether `input_values` is padded or not. - mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict - masked extracted features in *config.proj_codevector_dim* space. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxWav2Vec2LayerNormConvLayer(nn.Module): - config: Wav2Vec2Config - layer_id: int = 0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1 - self.out_conv_dim = self.config.conv_dim[self.layer_id] - - self.conv = nn.Conv( - features=self.config.conv_dim[self.layer_id], - kernel_size=(self.config.conv_kernel[self.layer_id],), - strides=(self.config.conv_stride[self.layer_id],), - use_bias=self.config.conv_bias, - kernel_init=jax.nn.initializers.he_normal(), - padding="VALID", - dtype=self.dtype, - ) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.activation = ACT2FN[self.config.feat_extract_activation] - - def __call__(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class FlaxConvWithWeightNorm(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - features=self.config.hidden_size, - kernel_size=(self.config.num_conv_pos_embeddings,), - kernel_init=jax.nn.initializers.he_normal(), - padding="VALID", - feature_group_count=self.config.num_conv_pos_embedding_groups, - dtype=self.dtype, - ) - weight_shape = ( - self.conv.features, - self.conv.features // self.conv.feature_group_count, - self.conv.kernel_size[0], - ) - self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape) - self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]) - self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,)) - self.prev_padding = self.conv.kernel_size[0] // 2 - - def _get_normed_weights(self): - weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :] - normed_weight_v = jnp.divide(self.weight_v, weight_v_norm) - normed_kernel = jnp.multiply(normed_weight_v, self.weight_g) - return normed_kernel - - def __call__(self, hidden_states): - kernel = self._get_normed_weights() - hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))) - hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states) - return hidden_states - - -class FlaxWav2Vec2PositionalConvEmbedding(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype) - self.activation = ACT2FN[self.config.feat_extract_activation] - self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0 - - def __call__(self, hidden_states): - hidden_states = hidden_states.transpose((0, 1, 2)) - - hidden_states = self.conv(hidden_states) - - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, : -self.num_pad_remove, :] - hidden_states = self.activation(hidden_states) - - hidden_states = hidden_states.transpose((0, 1, 2)) - return hidden_states - - -class FlaxConvLayersCollection(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - if self.config.feat_extract_norm == "layer": - self.layers = [ - FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_feat_extract_layers) - ] - elif self.config.feat_extract_norm == "group": - raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported") - else: - raise ValueError( - f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group'," - " 'layer']" - ) - - def __call__(self, hidden_states): - for i, conv_layer in enumerate(self.layers): - hidden_states = conv_layer(hidden_states) - return hidden_states - - -class FlaxWav2Vec2FeatureEncoder(nn.Module): - """Construct the features from raw audio waveform""" - - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype) - - def __call__(self, input_values, freeze_feature_encoder=False): - hidden_states = input_values[:, :, None] - hidden_states = self.conv_layers(hidden_states) - if freeze_feature_encoder: - hidden_states = jax.lax.stop_gradient(hidden_states) - return hidden_states - - -class FlaxWav2Vec2FeatureProjection(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.projection = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout) - - def __call__(self, hidden_states, deterministic=True): - norm_hidden_states = self.layer_norm(hidden_states) - hidden_states = self.projection(norm_hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states, norm_hidden_states - - -class FlaxWav2Vec2Attention(nn.Module): - config: Wav2Vec2Config - embed_dim: int - num_heads: int - dropout: float = 0.0 - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # get query proj - query_states = self.q_proj(hidden_states) - - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - if attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class FlaxWav2Vec2FeedForward(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout) - - self.intermediate_dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - if isinstance(self.config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[self.config.hidden_act] - else: - self.intermediate_act_fn = self.config.hidden_act - - self.output_dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.intermediate_dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic) - - hidden_states = self.output_dense(hidden_states) - hidden_states = self.output_dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.attention = FlaxWav2Vec2Attention( - config=self.config, - embed_dim=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype) - self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False): - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights = self.attention( - hidden_states, attention_mask=attention_mask, deterministic=deterministic - ) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward( - self.final_layer_norm(hidden_states), deterministic=deterministic - ) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = [ - FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -class FlaxWav2Vec2StableLayerNormEncoder(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout) - self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask=None, - deterministic=True, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - if attention_mask is not None: - # make sure padded tokens are not attended to - hidden_states = jnp.where( - jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0 - ) - - position_embeddings = self.pos_conv_embed(hidden_states) - - hidden_states = hidden_states + position_embeddings - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = self.layer_norm(outputs[0]) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_state,) - - if not return_dict: - outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions - ) - - -class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module): - """ - Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH - GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information. - """ - - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.num_groups = self.config.num_codevector_groups - self.num_vars = self.config.num_codevectors_per_group - - if self.config.codevector_dim % self.num_groups != 0: - raise ValueError( - f"`config.codevector_dim {self.config.codevector_dim} must be divisible by" - f" `config.num_codevector_groups` {self.num_groups} for concatenation" - ) - - # storage for codebook variables (codewords) - self.codevectors = self.param( - "codevectors", - jax.nn.initializers.uniform(), - (1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups), - ) - self.weight_proj = nn.Dense( - self.num_groups * self.num_vars, - kernel_init=jax.nn.initializers.normal(1.0), - dtype=self.dtype, - ) - - @staticmethod - def _compute_perplexity(probs, mask=None): - if mask is not None: - mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape) - probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs)) - marginal_probs = probs.sum(axis=0) / mask.sum() - else: - marginal_probs = probs.mean(axis=0) - - perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum() - return perplexity - - def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1): - batch_size, sequence_length, hidden_size = hidden_states.shape - - # project to codevector dim - hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1) - - if not deterministic: - # sample code vector probs via gumbel in differentiateable way - gumbel_rng = self.make_rng("gumbel") - gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape) - codevector_probs = nn.softmax((hidden_states + gumbels) / temperature) - - # compute perplexity - codevector_soft_dist = nn.softmax( - hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1 - ) - perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) - else: - # take argmax in non-differentiable way - # comptute hard codevector distribution (one hot) - codevector_idx = hidden_states.argmax(axis=-1) - codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0 - codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1) - perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) - - codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1) - # use probs to retrieve codevectors - codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors - codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1) - codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1) - - return codevectors, perplexity - - -class FlaxWav2Vec2Adapter(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - # hidden_states require down-projection if feature dims don't match - if self.config.output_hidden_size != self.config.hidden_size: - self.proj = nn.Dense( - self.config.output_hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - else: - self.proj = self.proj_layer_norm = None - - self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True): - # down-project hidden_states if required - if self.proj is not None and self.proj_layer_norm is not None: - hidden_states = self.proj(hidden_states) - hidden_states = self.proj_layer_norm(hidden_states) - - hidden_states = self.layers(hidden_states) - - return hidden_states - - -class FlaxWav2Vec2AdapterLayer(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - features=2 * self.config.output_hidden_size, - kernel_size=(self.config.adapter_kernel_size,), - strides=(self.config.adapter_stride,), - padding=((1, 1),), - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = nn.glu(hidden_states, axis=2) - - return hidden_states - - -class FlaxWav2Vec2AdapterLayersCollection(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.layers = [ - FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_adapter_layers) - ] - - def __call__(self, hidden_states): - for conv_layer in self.layers: - hidden_states = conv_layer(hidden_states) - - return hidden_states - - -class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = Wav2Vec2Config - base_model_prefix: str = "wav2vec2" - main_input_name = "input_values" - module_class: nn.Module = None - - def __init__( - self, - config: Wav2Vec2Config, - input_shape: tuple = (1, 1024), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_values = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_values) - params_rng, dropout_rng = jax.random.split(rng, 2) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) - def __call__( - self, - input_values, - attention_mask=None, - mask_time_indices=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - freeze_feature_encoder: bool = False, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_values.shape - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - return self.module.apply( - inputs, - jnp.array(input_values, dtype="f4"), - jnp.array(attention_mask, dtype="i4"), - mask_time_indices, - not train, - output_attentions, - output_hidden_states, - freeze_feature_encoder, - return_dict, - rngs=rngs, - ) - - def _get_feat_extract_output_lengths( - self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None - ): - return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) - - -class FlaxWav2Vec2Module(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype) - self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype) - self.masked_spec_embed = self.param( - "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,) - ) - - if self.config.do_stable_layer_norm: - self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype) - else: - raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.") - - self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None - - def __call__( - self, - input_values, - attention_mask=None, - mask_time_indices=None, - deterministic=True, - output_attentions=None, - output_hidden_states=None, - freeze_feature_encoder=False, - return_dict=None, - ): - extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder) - - # make sure that no loss is computed on padded inputs - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic) - if mask_time_indices is not None: # apply SpecAugment along time axis with given indices - hidden_states = jnp.where( - jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape), - jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape), - hidden_states, - ) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) - - if not return_dict: - return (hidden_states, extract_features) + encoder_outputs[1:] - - return FlaxWav2Vec2BaseModelOutput( - last_hidden_state=hidden_states, - extract_features=extract_features, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def _get_feat_extract_output_lengths( - self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None - ): - """ - Computes the output length of the convolutional layers - """ - - add_adapter = self.config.add_adapter if add_adapter is None else add_adapter - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - if add_adapter: - for _ in range(self.config.num_adapter_layers): - input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) - - return input_lengths - - def _get_feature_vector_attention_mask( - self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None - ): - # Effectively attention_mask.sum(-1), but not inplace to be able to run - # on inference mode. - non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1] - - output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) - - batch_size = attention_mask.shape[0] - - attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype) - # these two operations makes sure that all values - # before the output lengths indices are attended to - attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1) - attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool") - return attention_mask - - -@add_start_docstrings( - "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", - WAV2VEC2_START_DOCSTRING, -) -class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel): - module_class = FlaxWav2Vec2Module - - -FLAX_WAV2VEC2_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, FlaxWav2Vec2Model - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60") - >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor( - ... ds["speech"][0], sampling_rate=16_000, return_tensors="np" - ... ).input_values # Batch size 1 - >>> hidden_states = model(input_values).last_hidden_state - ``` -""" - -overwrite_call_docstring( - FlaxWav2Vec2Model, - WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING, -) -append_replace_return_docstrings( - FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config -) - - -class FlaxWav2Vec2ForCTCModule(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.final_dropout) - self.lm_head = nn.Dense( - self.config.vocab_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__( - self, - input_values, - attention_mask=None, - mask_time_indices=None, - deterministic=True, - output_attentions=None, - output_hidden_states=None, - freeze_feature_encoder=False, - return_dict=None, - ): - outputs = self.wav2vec2( - input_values, - attention_mask=attention_mask, - mask_time_indices=mask_time_indices, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - freeze_feature_encoder=freeze_feature_encoder, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - logits = self.lm_head(hidden_states) - - if not return_dict: - return (logits,) + outputs[2:] - - return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - def _get_feat_extract_output_lengths( - self, - input_lengths: Union[jnp.ndarray, int], - add_adapter: Optional[bool] = None, - ): - """ - Computes the output length of the convolutional layers - """ - - add_adapter = self.config.add_adapter if add_adapter is None else add_adapter - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - if add_adapter: - for _ in range(self.config.num_adapter_layers): - input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) - - return input_lengths - - -@add_start_docstrings( - "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).", - WAV2VEC2_START_DOCSTRING, -) -class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel): - module_class = FlaxWav2Vec2ForCTCModule - - -FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoProcessor, FlaxWav2Vec2ForCTC - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60") - >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor( - ... ds["speech"][0], sampling_rate=16_000, return_tensors="np" - ... ).input_values # Batch size 1 - >>> logits = model(input_values).logits - >>> predicted_ids = jnp.argmax(logits, axis=-1) - - >>> transcription = processor.decode(predicted_ids[0]) - >>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST" - ``` -""" - -overwrite_call_docstring( - FlaxWav2Vec2ForCTC, - WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING, -) -append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config) - - -class FlaxWav2Vec2ForPreTrainingModule(nn.Module): - config: Wav2Vec2Config - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) - self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout) - - self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype) - self.project_q = nn.Dense( - self.config.proj_codevector_dim, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.project_hid = nn.Dense( - self.config.proj_codevector_dim, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__( - self, - input_values, - attention_mask=None, - mask_time_indices=None, - gumbel_temperature: int = 1, - deterministic: bool = True, - output_attentions=None, - output_hidden_states=None, - freeze_feature_encoder=False, - return_dict=None, - ): - r""" - Returns: - - Example: - - ```python - - ```""" - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.wav2vec2( - input_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - mask_time_indices=mask_time_indices, - deterministic=deterministic, - freeze_feature_encoder=freeze_feature_encoder, - return_dict=return_dict, - ) - - # project all transformed features (including masked) to final vq dim - transformer_features = self.project_hid(outputs[0]) - - # quantize all (unmasked) extracted features and project to final vq dim - extract_features = self.dropout_features(outputs[1], deterministic=deterministic) - quantized_features, codevector_perplexity = self.quantizer( - extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature - ) - quantized_features = self.project_q(quantized_features) - - if not return_dict: - return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] - - return FlaxWav2Vec2ForPreTrainingOutput( - projected_states=transformer_features, - projected_quantized_states=quantized_features, - codevector_perplexity=codevector_perplexity, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def _get_feat_extract_output_lengths( - self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None - ): - """ - Computes the output length of the convolutional layers - """ - - add_adapter = self.config.add_adapter if add_adapter is None else add_adapter - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - if add_adapter: - for _ in range(self.config.num_adapter_layers): - input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) - - return input_lengths - - -@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING) -class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): - module_class = FlaxWav2Vec2ForPreTrainingModule - - @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) - # overwrite since has `gumbel_temperature` input - def __call__( - self, - input_values, - attention_mask=None, - mask_time_indices=None, - gumbel_temperature: int = 1, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - gumbel_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - freeze_feature_encoder: bool = False, - return_dict: Optional[bool] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - batch_size, sequence_length = input_values.shape - - if attention_mask is None: - attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - if gumbel_rng is not None: - rngs["gumbel"] = gumbel_rng - - inputs = {"params": params or self.params} - - return self.module.apply( - inputs, - jnp.array(input_values, dtype="f4"), - jnp.array(attention_mask, dtype="i4"), - mask_time_indices, - gumbel_temperature, - not train, - output_attentions, - output_hidden_states, - freeze_feature_encoder, - return_dict, - rngs=rngs, - ) - - -FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> import optax - >>> import numpy as np - >>> import jax.numpy as jnp - >>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining - >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices - >>> from datasets import load_dataset - - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60") - >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 - - >>> # compute masked indices - >>> batch_size, raw_sequence_length = input_values.shape - >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) - >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2) - - >>> outputs = model(input_values, mask_time_indices=mask_time_indices) - - >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) - >>> cosine_sim = optax.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states) - - >>> # show that cosine similarity is much higher than random - >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5 - ``` -""" - -overwrite_call_docstring( - FlaxWav2Vec2ForPreTraining, - WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING, -) -append_replace_return_docstrings( - FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config -) - - -__all__ = ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py deleted file mode 100644 index 54011bb969fd..000000000000 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ /dev/null @@ -1,1855 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow Wav2Vec2 model.""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass -from typing import Any - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFPreTrainedModel, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_wav2vec2 import Wav2Vec2Config - - -logger = logging.get_logger(__name__) - - -_HIDDEN_STATES_START_POSITION = 2 - -_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h" -_CONFIG_FOR_DOC = "Wav2Vec2Config" - - -LARGE_NEGATIVE = -1e8 - - -@dataclass -class TFWav2Vec2BaseModelOutput(ModelOutput): - """ - Output type of [`TFWav2Vec2BaseModelOutput`], with potential hidden states and attentions. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): - Sequence of extracted feature vectors of the last convolutional layer of the model. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - extract_features: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor] | None = None - attentions: tuple[tf.Tensor] | None = None - - -def _sample_without_replacement(distribution, num_samples): - """ - Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see - https://github.com/tensorflow/tensorflow/issues/9260 for more info - """ - z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1)) - _, indices = tf.nn.top_k(distribution + z, num_samples) - return indices - - -def _scatter_values_on_batch_indices(values, batch_indices, output_shape): - """ - Scatter function as in PyTorch with indices in format (batch_dim, indices) - """ - indices_shape = shape_list(batch_indices) - # broadcast batch dim to indices_shape - broad_casted_batch_dims = tf.reshape( - tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1] - ) - # transform batch_indices to pair_indices - pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0)) - # scatter values to pair indices - return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape) - - -def _compute_mask_indices( - shape: tuple[int, int], - mask_prob: float, - mask_length: int, - min_masks: int = 0, -) -> tf.Tensor: - """ - Computes random mask spans for a given shape - - Args: - shape: the shape for which to compute masks. - should be of size 2 where first element is batch size and 2nd is timesteps - attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements - mask_prob: - probability for each token to be chosen as start of the span to be masked. this will be multiplied by - number of timesteps divided by length of mask span to mask approximately this percentage of all elements. - however due to overlaps, the actual number will be smaller (unless no_overlap is True) - mask_length: size of the mask - min_masks: minimum number of masked spans - - Adapted from [fairseq's - data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376). - """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - tf.debugging.assert_less( - mask_length, - sequence_length, - message=( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" - f" `sequence_length`: {sequence_length}`" - ), - ) - - # compute number of masked spans in batch - num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,)) - num_masked_spans = tf.maximum(num_masked_spans, min_masks) - num_masked_spans = tf.cast(num_masked_spans, tf.int32) - - # make sure num masked indices <= sequence_length - num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans) - num_masked_spans = tf.squeeze(num_masked_spans) - - # SpecAugment mask to fill - spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32) - - # uniform distribution to sample from, make sure that offset samples are < sequence_length - uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1))) - - # get random indices to mask - spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans) - - # expand masked indices to masked spans - spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1) - spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length)) - spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length)) - - offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :] - offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1)) - offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length)) - - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # scatter indices to mask - spec_aug_mask = _scatter_values_on_batch_indices( - tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask) - ) - - return spec_aug_mask - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFWav2Vec2GroupNorm(keras.layers.Layer): - """ - From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization - """ - - def __init__( - self, - groups: int = 32, - axis: int = -1, - epsilon: float = 1e-3, - center: bool = True, - scale: bool = True, - beta_initializer: keras.initializers.Initializer = "zeros", - gamma_initializer: keras.initializers.Initializer = "ones", - beta_regularizer: keras.regularizers.Regularizer = None, - gamma_regularizer: keras.regularizers.Regularizer = None, - beta_constraint: keras.constraints.Constraint = None, - gamma_constraint: keras.constraints.Constraint = None, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - self.groups = groups - self.axis = axis - self.epsilon = epsilon - self.center = center - self.scale = scale - self.beta_initializer = keras.initializers.get(beta_initializer) - self.gamma_initializer = keras.initializers.get(gamma_initializer) - self.beta_regularizer = keras.regularizers.get(beta_regularizer) - self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) - self.beta_constraint = keras.constraints.get(beta_constraint) - self.gamma_constraint = keras.constraints.get(gamma_constraint) - self._check_axis() - - def build(self, input_shape): - self._check_if_input_shape_is_none(input_shape) - self._set_number_of_groups_for_instance_norm(input_shape) - self._check_size_of_dimensions(input_shape) - self._create_input_spec(input_shape) - - self._add_gamma_weight(input_shape) - self._add_beta_weight(input_shape) - self.built = True - super().build(input_shape) - - def call(self, inputs): - input_shape = keras.backend.int_shape(inputs) - tensor_input_shape = tf.shape(inputs) - - reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape) - - normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) - - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - outputs = tf.reshape(normalized_inputs, tensor_input_shape) - else: - outputs = normalized_inputs - - return outputs - - def get_config(self): - config = { - "groups": self.groups, - "axis": self.axis, - "epsilon": self.epsilon, - "center": self.center, - "scale": self.scale, - "beta_initializer": keras.initializers.serialize(self.beta_initializer), - "gamma_initializer": keras.initializers.serialize(self.gamma_initializer), - "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer), - "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer), - "beta_constraint": keras.constraints.serialize(self.beta_constraint), - "gamma_constraint": keras.constraints.serialize(self.gamma_constraint), - } - base_config = super().get_config() - return {**base_config, **config} - - def compute_output_shape(self, input_shape): - return input_shape - - def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): - group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - group_shape[self.axis] = input_shape[self.axis] // self.groups - group_shape.insert(self.axis, self.groups) - group_shape = tf.stack(group_shape) - reshaped_inputs = tf.reshape(inputs, group_shape) - return reshaped_inputs, group_shape - else: - return inputs, group_shape - - def _apply_normalization(self, reshaped_inputs, input_shape): - group_shape = keras.backend.int_shape(reshaped_inputs) - group_reduction_axes = list(range(1, len(group_shape))) - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - axis = -2 if self.axis == -1 else self.axis - 1 - else: - axis = -1 if self.axis == -1 else self.axis - 1 - group_reduction_axes.pop(axis) - - mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True) - - gamma, beta = self._get_reshaped_weights(input_shape) - normalized_inputs = tf.nn.batch_normalization( - reshaped_inputs, - mean=mean, - variance=variance, - scale=gamma, - offset=beta, - variance_epsilon=self.epsilon, - ) - return normalized_inputs - - def _get_reshaped_weights(self, input_shape): - broadcast_shape = self._create_broadcast_shape(input_shape) - gamma = None - beta = None - if self.scale: - gamma = tf.reshape(self.gamma, broadcast_shape) - - if self.center: - beta = tf.reshape(self.beta, broadcast_shape) - return gamma, beta - - def _check_if_input_shape_is_none(self, input_shape): - dim = input_shape[self.axis] - if dim is None: - raise ValueError( - "Axis " - + str(self.axis) - + " of input tensor should have a defined dimension but the layer received an input with shape " - + str(input_shape) - + "." - ) - - def _set_number_of_groups_for_instance_norm(self, input_shape): - dim = input_shape[self.axis] - - if self.groups == -1: - self.groups = dim - - def _check_size_of_dimensions(self, input_shape): - dim = input_shape[self.axis] - if dim < self.groups: - raise ValueError( - "Number of groups (" - + str(self.groups) - + ") cannot be more than the number of channels (" - + str(dim) - + ")." - ) - - if dim % self.groups != 0: - raise ValueError( - "Number of groups (" - + str(self.groups) - + ") must be a multiple of the number of channels (" - + str(dim) - + ")." - ) - - def _check_axis(self): - if self.axis == 0: - raise ValueError( - "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead" - ) - - def _create_input_spec(self, input_shape): - dim = input_shape[self.axis] - self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim}) - - def _add_gamma_weight(self, input_shape): - dim = input_shape[self.axis] - shape = (dim,) - - if self.scale: - self.gamma = self.add_weight( - shape=shape, - name="gamma", - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - constraint=self.gamma_constraint, - ) - else: - self.gamma = None - - def _add_beta_weight(self, input_shape): - dim = input_shape[self.axis] - shape = (dim,) - - if self.center: - self.beta = self.add_weight( - shape=shape, - name="beta", - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - constraint=self.beta_constraint, - ) - else: - self.beta = None - - def _create_broadcast_shape(self, input_shape): - broadcast_shape = [1] * len(input_shape) - is_instance_norm = (input_shape[self.axis] // self.groups) == 1 - if not is_instance_norm: - broadcast_shape[self.axis] = input_shape[self.axis] // self.groups - broadcast_shape.insert(self.axis, self.groups) - else: - broadcast_shape[self.axis] = self.groups - return broadcast_shape - - -class TFWav2Vec2WeightNormConv1D(keras.layers.Conv1D): - """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm""" - - def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs): - super().__init__( - filters=filters, - kernel_size=kernel_size, - groups=groups, - padding="valid", - use_bias=True, - bias_initializer="he_normal", - **kwargs, - ) - self.explicit_padding = explicit_padding - self.filter_axis = 2 - self.kernel_norm_axes = tf.constant([0, 1]) - - def _init_norm(self): - """Set the norm of the weight vector.""" - kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes)) - self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis]) - - def _normalize_kernel(self): - """Generate normalized weights.""" - kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g) - self.kernel = tf.transpose(kernel) - - def build(self, input_shape): - if not self.built: - super().build(input_shape) - - self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True) - self.weight_v = self.kernel - - self.weight_g = self.add_weight( - name="weight_g", - shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1), - initializer="ones", - dtype=self.weight_v.dtype, - trainable=True, - ) - self._init_norm() - self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True) - - def call(self, inputs): - # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent. - # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls - # a functional 1d convolution with normalized weights that it generates (but does not store!) - self._normalize_kernel() - - padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0))) - output = super().call(padded_inputs) - - return output - - -class TFWav2Vec2NoLayerNormConvLayer(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.activation = get_tf_activation(config.feat_extract_activation) - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - - -class TFWav2Vec2LayerNormConvLayer(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps) - self.activation = get_tf_activation(config.feat_extract_activation) - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.out_conv_dim]) - - -class TFWav2Vec2GroupNormConvLayer(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = keras.layers.Conv1D( - filters=self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - strides=config.conv_stride[layer_id], - use_bias=config.conv_bias, - name="conv", - ) - self.activation = get_tf_activation(config.feat_extract_activation) - self.layer_norm = TFWav2Vec2GroupNorm( - groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm" - ) - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.in_conv_dim]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.out_conv_dim]) - - -class TFWav2Vec2PositionalConvEmbedding(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.conv = TFWav2Vec2WeightNormConv1D( - filters=config.hidden_size, - kernel_size=config.num_conv_pos_embeddings, - groups=config.num_conv_pos_embedding_groups, - explicit_padding=config.num_conv_pos_embeddings // 2, - name="conv", - ) - self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings) - self.activation = get_tf_activation(config.feat_extract_activation) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.conv(hidden_states) - hidden_states = self.padding(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv", None) is not None: - with tf.name_scope(self.conv.name): - self.conv.build([None, None, self.config.hidden_size]) - - -class TFWav2Vec2SamePadLayer(keras.layers.Layer): - def __init__(self, num_conv_pos_embeddings, **kwargs): - super().__init__(**kwargs) - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 - - def call(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, : -self.num_pad_remove, :] - return hidden_states - - -class TFWav2Vec2FeatureEncoder(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None: - super().__init__(**kwargs) - - if config.feat_extract_norm == "group": - conv_layers = [TFWav2Vec2GroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [ - TFWav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i + 1}") - for i in range(config.num_feat_extract_layers - 1) - ] - elif config.feat_extract_norm == "layer": - conv_layers = [ - TFWav2Vec2LayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}") - for i in range(config.num_feat_extract_layers) - ] - else: - raise ValueError( - f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" - ) - self.conv_layers = conv_layers - - def call(self, input_values): - hidden_states = tf.expand_dims(input_values, -1) - for conv_layer in self.conv_layers: - hidden_states = conv_layer(hidden_states) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv_layers", None) is not None: - for conv_layer in self.conv_layers: - with tf.name_scope(conv_layer.name): - conv_layer.build(None) - - -class TFWav2Vec2FeatureExtractor(TFWav2Vec2FeatureEncoder): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -class TFWav2Vec2FeatureProjection(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.projection = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="projection", - ) - self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout) - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - norm_hidden_states = self.layer_norm(hidden_states) - hidden_states = self.projection(norm_hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - return hidden_states, norm_hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.conv_dim[-1]]) - if getattr(self, "projection", None) is not None: - with tf.name_scope(self.projection.name): - self.projection.build([None, None, self.config.conv_dim[-1]]) - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2 -class TFWav2Vec2Attention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFWav2Vec2FeedForward(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - - self.intermediate_dropout = keras.layers.Dropout(config.activation_dropout) - - self.intermediate_dense = keras.layers.Dense( - units=config.intermediate_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="intermediate_dense", - ) - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - - self.output_dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - bias_initializer="zeros", - name="output_dense", - ) - self.output_dropout = keras.layers.Dropout(config.hidden_dropout) - self.config = config - - def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.intermediate_dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.intermediate_dropout(hidden_states, training=training) - - hidden_states = self.output_dense(hidden_states) - hidden_states = self.output_dropout(hidden_states, training=training) - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "intermediate_dense", None) is not None: - with tf.name_scope(self.intermediate_dense.name): - self.intermediate_dense.build([None, None, self.config.hidden_size]) - if getattr(self, "output_dense", None) is not None: - with tf.name_scope(self.output_dense.name): - self.output_dense.build([None, None, self.config.intermediate_size]) - - -class TFWav2Vec2EncoderLayer(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - self.attention = TFWav2Vec2Attention( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - name="attention", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attn_residual = hidden_states - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, training=training - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = attn_residual + hidden_states - - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states + self.feed_forward(hidden_states) - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "feed_forward", None) is not None: - with tf.name_scope(self.feed_forward.name): - self.feed_forward.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.config.hidden_size]) - - -class TFWav2Vec2EncoderLayerStableLayerNorm(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - self.attention = TFWav2Vec2Attention( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - name="attention", - ) - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ) -> tuple[tf.Tensor]: - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, training=training - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "feed_forward", None) is not None: - with tf.name_scope(self.feed_forward.name): - self.feed_forward.build(None) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.config.hidden_size]) - - -class TFWav2Vec2Encoder(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer = [TFWav2Vec2EncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - position_embeddings = self.pos_conv_embed(hidden_states) - hidden_states = hidden_states + position_embeddings - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = np.random.uniform(0, 1) - if training and (dropout_probability < self.config.layerdrop): # skip the layer - continue - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "pos_conv_embed", None) is not None: - with tf.name_scope(self.pos_conv_embed.name): - self.pos_conv_embed.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFWav2Vec2EncoderStableLayerNorm(keras.layers.Layer): - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed") - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(config.hidden_dropout) - self.layer = [ - TFWav2Vec2EncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) - ] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = False, - output_hidden_states: bool | None = False, - return_dict: bool | None = True, - training: bool | None = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) - attention_mask = _expand_mask(attention_mask) - else: - attention_mask = None - - position_embeddings = self.pos_conv_embed(hidden_states) - hidden_states = hidden_states + position_embeddings - hidden_states = self.dropout(hidden_states, training=training) - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = np.random.uniform(0, 1) - if training and (dropout_probability < self.config.layerdrop): # skip the layer - continue - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "pos_conv_embed", None) is not None: - with tf.name_scope(self.pos_conv_embed.name): - self.pos_conv_embed.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFWav2Vec2MainLayer(keras.layers.Layer): - config_class = Wav2Vec2Config - - def __init__(self, config: Wav2Vec2Config, **kwargs): - super().__init__(**kwargs) - self.config = config - self.feature_extractor = TFWav2Vec2FeatureEncoder(config, name="feature_extractor") - self.feature_projection = TFWav2Vec2FeatureProjection(config, name="feature_projection") - - if config.do_stable_layer_norm: - self.encoder = TFWav2Vec2EncoderStableLayerNorm(config, name="encoder") - else: - self.encoder = TFWav2Vec2Encoder(config, name="encoder") - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0: - self.masked_spec_embed = self.add_weight( - shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed" - ) - if getattr(self, "feature_extractor", None) is not None: - with tf.name_scope(self.feature_extractor.name): - self.feature_extractor.build(None) - if getattr(self, "feature_projection", None) is not None: - with tf.name_scope(self.feature_projection.name): - self.feature_projection.build(None) - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - - def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): - """ - Computes the output length of the convolutional layers - """ - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return (input_length - kernel_size) // stride + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - return input_lengths - - def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None): - """ - Masks extracted features along time axis and/or along feature axis according to - [SpecAugment](https://huggingface.co/papers/1904.08779). - """ - batch_size, sequence_length, hidden_size = shape_list(hidden_states) - - # `config.apply_spec_augment` can set masking to False - if not getattr(self.config, "apply_spec_augment", True): - return hidden_states - - if mask_time_indices is not None: - # apply SpecAugment along time axis with given mask_time_indices - hidden_states = tf.where( - tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), - self.masked_spec_embed[tf.newaxis, tf.newaxis, :], - hidden_states, - ) - - elif self.config.mask_time_prob > 0: - # generate indices & apply SpecAugment along time axis - mask_time_indices = _compute_mask_indices( - (batch_size, sequence_length), - mask_prob=self.config.mask_time_prob, - mask_length=self.config.mask_time_length, - min_masks=2, - ) - hidden_states = tf.where( - tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), - self.masked_spec_embed[tf.newaxis, tf.newaxis, :], - hidden_states, - ) - - # apply SpecAugment along feature axis - if self.config.mask_feature_prob > 0: - mask_feature_indices = _compute_mask_indices( - (batch_size, hidden_size), - mask_prob=self.config.mask_feature_prob, - mask_length=self.config.mask_feature_length, - ) - hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0) - - return hidden_states - - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - **kwargs: Any, - ): - extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training) - # extract_features = tf.transpose(extract_features, perm=(0, 2, 1)) - - if attention_mask is not None: - # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1)) - - attention_mask = tf.sequence_mask( - output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype - ) - - hidden_states, extract_features = self.feature_projection(extract_features, training=training) - - mask_time_indices = kwargs.get("mask_time_indices") - if training: - hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = encoder_outputs[0] - - if not return_dict: - return (hidden_states, extract_features) + encoder_outputs[1:] - - return TFWav2Vec2BaseModelOutput( - last_hidden_state=hidden_states, - extract_features=extract_features, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = Wav2Vec2Config - base_model_prefix = "wav2vec2" - main_input_name = "input_values" - - @property - def input_signature(self): - return { - "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"), - "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"), - } - - @property - def dummy_inputs(self): - return { - "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32), - "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32), - } - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - logger.warning( - f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish " - "to train/fine-tune this model, you need a GPU or a TPU" - ) - - def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None): - """ - Computes the output length of the convolutional layers - """ - add_adapter = self.config.add_adapter if add_adapter is None else add_adapter - - def _conv_out_length(input_length, kernel_size, stride): - return tf.math.floordiv(input_length - kernel_size, stride) + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - if add_adapter: - for _ in range(self.config.num_adapter_layers): - input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) - return input_lengths - - def _get_feature_vector_attention_mask( - self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None - ): - non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1] - output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) - output_lengths = tf.cast(output_lengths, tf.int32) - batch_size = tf.shape(attention_mask)[0] - # check device here - attention_mask = tf.zeros( - (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask" - ) # these two operations makes sure that all values before the output lengths idxs are attended to - ## check device - attention_mask = tf.tensor_scatter_nd_update( - attention_mask, - indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1), - updates=tf.ones([batch_size], dtype=attention_mask.dtype), - ) - attention_mask = tf.reverse(attention_mask, axis=[-1]) - attention_mask = tf.cumsum(attention_mask, axis=-1) - attention_mask = tf.reverse(attention_mask, axis=[-1]) - attention_mask = tf.cast(attention_mask, tf.bool) - return attention_mask - - -WAV2VEC2_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_values` only and nothing else: `model(input_values)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_values": input_values, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -WAV2VEC2_INPUTS_DOCSTRING = r""" - Args: - input_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_values` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_values` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False``): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare TFWav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", - WAV2VEC2_START_DOCSTRING, -) -class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): - def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.config = config - self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") - - @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, TFWav2Vec2Model - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") - >>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1 - >>> hidden_states = model(input_values).last_hidden_state - ```""" - - output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states - output_attentions = output_attentions if output_attentions else self.config.output_attentions - return_dict = return_dict if return_dict else self.config.return_dict - - outputs = self.wav2vec2( - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wav2vec2", None) is not None: - with tf.name_scope(self.wav2vec2.name): - self.wav2vec2.build(None) - - -@add_start_docstrings( - """TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", - WAV2VEC2_START_DOCSTRING, -) -class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): - def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") - self.dropout = keras.layers.Dropout(config.final_dropout) - self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head") - self.output_hidden_size = ( - config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size - ) - - def freeze_feature_extractor(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameters will - not be updated during training. - """ - warnings.warn( - "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " - "Please use the equivalent `freeze_feature_encoder` method instead.", - FutureWarning, - ) - self.freeze_feature_encoder() - - def freeze_feature_encoder(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameter will - not be updated during training. - """ - self.wav2vec2.feature_extractor.trainable = False - - @unpack_inputs - @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - labels: tf.Tensor | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> TFCausalLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_values` docstring) Tokens with indices set to `-100` are ignored (masked), - the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoProcessor, TFWav2Vec2ForCTC - >>> from datasets import load_dataset - >>> from torchcodec.decoders import AudioDecoder - - >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") - >>> model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") - - - >>> def map_to_array(example): - ... example["speech"] = example["audio"]["array"] - ... return example - - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1 - >>> logits = model(input_values).logits - >>> predicted_ids = tf.argmax(logits, axis=-1) - - >>> transcription = processor.decode(predicted_ids[0]) - - >>> # compute loss - >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST" - - >>> # Pass transcription as `text` to encode labels - >>> labels = processor(text=transcription, return_tensors="tf").input_ids - - >>> loss = model(input_values, labels=labels).loss - ```""" - if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size: - raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") - - outputs = self.wav2vec2( - input_values=input_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, training=training) - - logits = self.lm_head(hidden_states) - - if labels is not None: - attention_mask = ( - attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32) - ) - input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1)) - - # assuming that padded tokens are filled with -100 - # when not being attended to - labels_mask = tf.cast(labels >= 0, tf.int32) - target_lengths = tf.reduce_sum(labels_mask, axis=-1) - - loss = tf.nn.ctc_loss( - logits=logits, - labels=labels, - logit_length=input_lengths, - label_length=target_lengths, - blank_index=self.config.pad_token_id, - logits_time_major=False, - ) - - if self.config.ctc_loss_reduction == "sum": - loss = tf.reduce_sum(loss) - if self.config.ctc_loss_reduction == "mean": - loss = tf.reduce_mean(loss) - - loss = tf.reshape(loss, (1,)) - else: - loss = None - - if not return_dict: - output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wav2vec2", None) is not None: - with tf.name_scope(self.wav2vec2.name): - self.wav2vec2.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.output_hidden_size]) - - -class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") - self.num_layers = config.num_hidden_layers + 1 - with tf.name_scope(self._name_scope()): - if config.use_weighted_layer_sum: - self.layer_weights = self.add_weight( - shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights" - ) - self.config = config - self.projector = keras.layers.Dense(units=config.classifier_proj_size, name="projector") - self.classifier = keras.layers.Dense(units=config.num_labels, activation=None, name="classifier") - - def freeze_feature_extractor(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameters will - not be updated during training. - """ - warnings.warn( - "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " - "Please use the equivalent `freeze_feature_encoder` method instead.", - FutureWarning, - ) - self.freeze_feature_encoder() - - def freeze_feature_encoder(self): - """ - Calling this function will disable the gradient computation for the feature encoder so that its parameter will - not be updated during training. - """ - self.wav2vec2.feature_extractor.trainable = False - - def freeze_base_model(self): - """ - Calling this function will disable the gradient computation for the base model so that its parameters will not - be updated during training. Only the classification head will be updated. - """ - for layer in self.wav2vec2.layers: - layer.trainable = False - - @unpack_inputs - def call( - self, - input_values: tf.Tensor, - attention_mask: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: tf.Tensor | None = None, - training: bool = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states - - outputs = self.wav2vec2( - input_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - if self.config.use_weighted_layer_sum: - hidden_states = outputs[_HIDDEN_STATES_START_POSITION] - hidden_states = tf.stack(hidden_states, axis=1) - norm_weights = tf.nn.softmax(self.layer_weights, axis=-1) - hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1) - else: - hidden_states = outputs[0] - - hidden_states = self.projector(hidden_states) - if attention_mask is None: - pooled_output = tf.reduce_mean(hidden_states, axis=1) - else: - padding_mask = self._get_feature_vector_attention_mask(shape_list(hidden_states)[1], attention_mask) - padding_mask_float = tf.cast(padding_mask, hidden_states.dtype) - hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1)) - pooled_output = tf.divide( - tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1) - ) - logits = self.classifier(pooled_output) - loss = None - if labels is not None: - loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) - loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels])) - if not return_dict: - output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "wav2vec2", None) is not None: - with tf.name_scope(self.wav2vec2.name): - self.wav2vec2.build(None) - if getattr(self, "projector", None) is not None: - with tf.name_scope(self.projector.name): - self.projector.build([None, None, self.config.hidden_size]) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.classifier_proj_size]) - - -__all__ = ["TFWav2Vec2ForCTC", "TFWav2Vec2Model", "TFWav2Vec2PreTrainedModel", "TFWav2Vec2ForSequenceClassification"] diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py deleted file mode 100644 index 183fdd58f42c..000000000000 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ /dev/null @@ -1,1707 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax whisper model.""" - -import math -import random -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, - FlaxSequenceClassifierOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_whisper import WhisperConfig - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" -_CONFIG_FOR_DOC = "WhisperConfig" - -remat = nn_partitioning.remat - - -def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array: - """Returns sinusoids for positional embedding""" - length, channels = shape - if channels % 2 != 0: - raise ValueError( - f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." - ) - log_timescale_increment = math.log(10000) / (channels // 2 - 1) - inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2)) - scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1) - return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype) - - -WHISPER_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision - inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`. - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] - and [`~FlaxPreTrainedModel.to_bf16`]. -""" - -WHISPER_INPUTS_DOCSTRING = r""" - Args: - input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by - loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a - `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library - (`pip install soundfile`). - To prepare the array into `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting - the features, padding and conversion into a tensor of type `numpy.ndarray`. - See [`~WhisperFeatureExtractor.__call__`] - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but - is not used. By default the silence in the input log mel spectrogram are ignored. - decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using - [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as - the starting token for `decoder_input_ids` generation. - decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 - in [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't - use masking, but this argument is preserved for compatibility. By default the silence in the input log mel - spectrogram are ignored. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -WHISPER_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by - loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via - the torchcodec library (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). - To prepare the array into `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting - the mel features, padding and conversion into a tensor of type `numpy.ndarray`. - See [`~WhisperFeatureExtractor.__call__`]. - attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but - is not used. By default the silence in the input log mel spectrogram are ignored. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -WHISPER_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using - [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - [What are decoder input IDs?](../glossary#decoder-input-ids) - encoder_outputs (`tuple(tuple(numpy.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, - but it is not used. By default the silence in the input log mel spectrogram are ignored. - decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 - in [the paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - past_key_values (`dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxWhisperAttention(nn.Module): - config: WhisperConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj = dense(use_bias=self.bias) - self.k_proj = dense(use_bias=False) - self.v_proj = dense(use_bias=self.bias) - self.out_proj = dense(use_bias=self.bias) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool" - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - query_states = self.q_proj(hidden_states) - - if is_cross_attention: - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, - (0, 0, mask_shift, 0), - (1, 1, query_length, max_decoder_length), - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - def _split_heads(self, hidden_state) -> jnp.ndarray: - return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_state) -> jnp.ndarray: - return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only - # attend to those key positions that have already been generated and cached, not the - # remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - - return key, value, attention_mask - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper -class FlaxWhisperEncoderLayer(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class FlaxWhisperEncoderLayerCollection(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3)) - self.layers = [ - FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - else: - self.layers = [ - FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - -# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper -class FlaxWhisperDecoderLayer(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.encoder_attn = FlaxWhisperAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.fc1 = nn.Dense( - self.config.decoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -class FlaxWhisperDecoderLayerCollection(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6)) - self.layers = [ - FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - else: - self.layers = [ - FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.decoder_layers) - ] - self.layerdrop = self.config.decoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - output_attentions, - deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxWhisperEncoder(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self) -> None: - self.conv1 = nn.Conv( - self.config.d_model, - kernel_size=(3,), - padding=1, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - self.conv2 = nn.Conv( - self.config.d_model, - kernel_size=(3,), - strides=2, - padding=1, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, - ) - - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - self.layers = FlaxWhisperEncoderLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - self.embed_positions = nn.Embed( - self.config.max_source_positions, - self.config.d_model, - dtype=self.dtype, - embedding_init=sinusoidal_embedding_init, - ) - - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_features: jnp.ndarray, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): - raise ValueError( - "input_features.shape[1:], must be equal to (self.config.num_mel_bins," - f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be" - f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))" - ) - - input_features = input_features.transpose(0, 2, 1) - hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) - hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) - - embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) - # freeze the sinusoidal embeddings by stopping the back-prop - embed_positions = jax.lax.stop_gradient(embed_positions) - hidden_states = hidden_states + embed_positions - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask=None, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxWhisperDecoder(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self) -> None: - self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype) - self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype) - - self.layers = FlaxWhisperDecoderLayerCollection( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5) - - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: jnp.ndarray, - position_ids: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - input_embeds = self.embed_tokens(input_ids) - position_embeds = self.embed_positions(position_ids) - - hidden_states = input_embeds + position_embeds - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - # update the last element in `hidden_states` after applying `layernorm` above - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxWhisperModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self) -> None: - self.encoder = FlaxWhisperEncoder( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.decoder = FlaxWhisperDecoder( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - def __call__( - self, - input_features: jnp.ndarray, - decoder_input_ids: jnp.ndarray, - decoder_attention_mask: jnp.ndarray, - decoder_position_ids: jnp.ndarray, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_features, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - -class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): - config_class = WhisperConfig - base_model_prefix: str = "model" - main_input_name = "input_features" - module_class: nn.Module = None - - def __init__( - self, - config: WhisperConfig, - input_shape: Optional[tuple[int]] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - if input_shape is None: - input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_features = jnp.zeros(input_shape, dtype="f4") - input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) - - decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_features=input_features, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig) - def encode( - self, - input_features: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - **kwargs, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") - >>> input_features = inputs.input_features - >>> encoder_outputs = model.encode(input_features=input_features) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_features, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_features, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_features=jnp.array(input_features, dtype="f4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration - >>> from datasets import load_dataset - >>> import jax.numpy as jnp - - >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features - - >>> encoder_outputs = model.encode(input_features=input_features) - >>> decoder_start_token_id = model.config.decoder_start_token_id - - >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxWhisperAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) - def __call__( - self, - input_features: jnp.ndarray, - decoder_input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare decoder inputs - if decoder_position_ids is None: - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_features=jnp.array(input_features, dtype="f4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.", - WHISPER_START_DOCSTRING, -) -class FlaxWhisperModel(FlaxWhisperPreTrainedModel): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - module_class = FlaxWhisperModule - - -append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) - - -class FlaxWhisperForConditionalGenerationModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self) -> None: - self.model = FlaxWhisperModule( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_features, - decoder_input_ids, - decoder_attention_mask: jnp.ndarray = None, - decoder_position_ids: jnp.ndarray = None, - position_ids: jnp.ndarray = None, - attention_mask: jnp.ndarray = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_features=input_features, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) -class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): - module_class = FlaxWhisperForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: Optional[dict] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") - >>> input_features = inputs.input_features - >>> encoder_outputs = model.encode(input_features=input_features) - >>> decoder_start_token_id = model.config.decoder_start_token_id - - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxWhisperAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def generate( - self, - input_features, - generation_config=None, - logits_processor=None, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, - **kwargs, - ): - if generation_config is None: - generation_config = self.generation_config - - if return_timestamps is not None: - generation_config.return_timestamps = return_timestamps - - if task is not None: - generation_config.task = task - - if is_multilingual is not None: - generation_config.is_multilingual = is_multilingual - - if language is not None: - generation_config.language = language - - if kwargs is not None and "decoder_input_ids" in kwargs: - decoder_input_length = len(kwargs["decoder_input_ids"]) - else: - decoder_input_length = 1 - - forced_decoder_ids = [] - - if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: - if hasattr(generation_config, "language"): - forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) - else: - forced_decoder_ids.append((1, None)) - - if hasattr(generation_config, "task"): - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) - - if ( - hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps - ) or return_timestamps: - logits_processor = [ - FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length) - ] - else: - if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if len(forced_decoder_ids) > 0: - generation_config.forced_decoder_ids = forced_decoder_ids - - return super().generate( - input_features, - generation_config, - logits_processor=logits_processor, - **kwargs, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jax.Array] = None, - decoder_attention_mask: Optional[jax.Array] = None, - encoder_outputs=None, - **kwargs, - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 - return model_kwargs - - -FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r""" - Returns: - - Transcription example: - - ```python - >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") - >>> input_features = inputs.input_features - >>> generated_ids = model.generate(input_ids=input_features) - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> transcription - ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ``` -""" - -overwrite_call_docstring( - FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxWhisperForAudioClassificationModule(nn.Module): - config: WhisperConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self) -> None: - self.encoder = FlaxWhisperEncoder( - config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.config.is_encoder_decoder = False - num_layers = self.config.num_hidden_layers + 1 - if self.config.use_weighted_layer_sum: - self.layer_weights = jnp.repeat(1 / num_layers, num_layers) - self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_features, - encoder_outputs=None, - output_attentions=None, - output_hidden_states: bool = True, - return_dict: bool = True, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_features, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.config.use_weighted_layer_sum: - hidden_states = jnp.stack(encoder_outputs, axis=1) - norm_weights = jax.nn.softmax(self.layer_weights, axis=-1) - hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1) - else: - hidden_states = encoder_outputs[0] - - hidden_states = self.projector(hidden_states) - pooled_output = jnp.mean(hidden_states, axis=1) - - logits = self.classifier(pooled_output) - - if not return_dict: - return (logits,) + encoder_outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING) -class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel): - module_class = FlaxWhisperForAudioClassificationModule - dtype: jnp.dtype = jnp.float32 - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_features = jnp.zeros(input_shape, dtype="f4") - input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - input_features=input_features, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) - def __call__( - self, - input_features: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - dropout_rng: PRNGKey = None, - **kwargs, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - return self.module.apply( - {"params": params or self.params}, - input_features=jnp.array(input_features, dtype="f4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - -FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r""" - Returns: - - Transcription example: - - ```python - >>> import jax.numpy as jnp - >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification - >>> from datasets import load_dataset - - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") - >>> model = FlaxWhisperForAudioClassification.from_pretrained( - ... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True - ... ) - >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) - - >>> sample = next(iter(ds)) - - >>> inputs = feature_extractor( - ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np" - ... ) - >>> input_features = inputs.input_features - - >>> logits = model(input_features).logits - - >>> predicted_class_ids = jnp.argmax(logits).item() - >>> predicted_label = model.config.id2label[predicted_class_ids] - >>> predicted_label - 'af_za' - ``` -""" - -overwrite_call_docstring( - FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC -) - - -__all__ = [ - "FlaxWhisperForConditionalGeneration", - "FlaxWhisperModel", - "FlaxWhisperPreTrainedModel", - "FlaxWhisperForAudioClassification", -] diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py deleted file mode 100644 index c768db3c3070..000000000000 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ /dev/null @@ -1,1754 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TensorFlow Whisper model.""" - -from __future__ import annotations - -import math -import random - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...generation.configuration_utils import GenerationConfig -from ...generation.tf_logits_process import TFLogitsProcessorList -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_whisper import WhisperConfig -from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "WhisperConfig" - - -LARGE_NEGATIVE = -1e8 - - -def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor: - """Returns sinusoids for positional embedding""" - length, channels = shape - if channels % 2 != 0: - raise ValueError( - f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." - ) - log_timescale_increment = math.log(10000) / (channels // 2 - 1) - inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32)) - scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1)) - return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype) - - -# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right -def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - pad_token_id = tf.cast(pad_token_id, input_ids.dtype) - decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) - start_tokens = tf.fill( - (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) - ) - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -class TFWhisperPositionalEmbedding(keras.layers.Layer): - def __init__( - self, - num_positions: int, - embedding_dim: int, - padding_idx: int | None = None, - embedding_initializer=None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_positions = num_positions - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.embedding_initializer = keras.initializers.get(embedding_initializer) - - def build(self, input_shape): - self.weight = self.add_weight( - name="weight", - shape=[self.num_positions, self.embedding_dim], - initializer=self.embedding_initializer, - trainable=True, - ) - super().build(input_shape) - - def call(self, input_ids, past_key_values_length=0): - past_key_values_length = tf.cast(past_key_values_length, tf.int32) - gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length - return tf.gather(self.weight, gather_indices) - - -class TFWhisperAttention(keras.layers.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=False, name="k_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention._shape with BART->whisper - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention.call with BART->whisper - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextEncoderLayer with Speech2Text->Whisper -class TFWhisperEncoderLayer(keras.layers.Layer): - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFWhisperAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False - ): - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)` - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - training=training, - ) - - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) - - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.encoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -# Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextDecoderLayer with Speech2Text->Whisper -class TFWhisperDecoderLayer(keras.layers.Layer): - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - - self.self_attn = TFWhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", - is_decoder=True, - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFWhisperAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - name="encoder_attn", - is_decoder=True, - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - def call( - self, - hidden_states, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training=False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`tf.Tensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - `(decoder_attention_heads,)` - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - `(decoder_attention_heads,)` - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - training=training, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - training=training, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.decoder_ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - - -class TFWhisperPreTrainedModel(TFPreTrainedModel): - config_class = WhisperConfig - base_model_prefix = "model" - main_input_name = "input_features" - - def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor) -> int: - """ - Computes the output length of the convolutional layers - """ - input_lengths = (input_lengths - 1) // 2 + 1 - - return input_lengths - - @property - def dummy_inputs(self) -> dict[str, tf.Tensor]: - """ - Dummy inputs to build the network. - - Returns: - `dict[str, tf.Tensor]`: The dummy inputs. - """ - return { - self.main_input_name: tf.random.uniform( - [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32 - ), - "decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32), - } - - @property - def input_signature(self): - return { - "input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"), - "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"), - "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"), - } - - -WHISPER_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`WhisperConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -WHISPER_INPUTS_DOCSTRING = r""" - Args: - input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`): - Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained - by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a - `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or the soundfile library - (`pip install soundfile`). - To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the - fbank features, padding and conversion into a tensor of type `tf.Tensor`. - See [`~WhisperFeatureExtractor.__call__`] - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`SpeechToTextTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - SpeechToText uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should read - [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the - paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@keras_serializable -class TFWhisperEncoder(keras.layers.Layer): - config_class = WhisperConfig - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`TFWhisperEncoderLayer`]. - - Args: - config: WhisperConfig - embed_tokens (TFWhisperEmbedding): output embedding - """ - - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layerdrop = config.encoder_layerdrop - - self.embed_dim = config.d_model - self.num_mel_bins = config.num_mel_bins - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_source_positions - self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0 - - # Padding is added in call() to match the PyTorch implementation - self.conv1 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1") - self.conv2 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2") - - self.embed_positions = TFWhisperPositionalEmbedding( - num_positions=self.max_source_positions, - embedding_dim=self.embed_dim, - embedding_initializer=sinusoidal_embedding_init, - name="embed_positions", - ) - self.embed_positions.trainable = False - - self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - self.dropout = keras.layers.Dropout(config.dropout) - - @unpack_inputs - def call( - self, - input_features=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`): - Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a - `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or - the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features, - padding and conversion into a tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`] - head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TF 2.0 layers can't use channels first format when running on CPU. - input_features = tf.transpose(input_features, perm=(0, 2, 1)) - input_features = tf.pad(input_features, [[0, 0], [1, 1], [0, 0]]) - inputs_embeds = keras.activations.gelu(self.conv1(input_features)) - inputs_embeds = tf.pad(inputs_embeds, [[0, 0], [1, 1], [0, 0]]) - inputs_embeds = keras.activations.gelu(self.conv2(inputs_embeds)) - inputs_embeds = tf.transpose(inputs_embeds, perm=(0, 1, 2)) - - embed_pos = self.embed_positions(input_ids=tf.zeros((1, self.max_source_positions), dtype=tf.int32)) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.dropout(hidden_states, training=training) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - tf.debugging.assert_equal( - shape_list(head_mask)[0], - len(self.encoder_layers), - message=( - f"The head_mask should be specified for {len(self.encoder_layers)} layers, but it is for" - f" {shape_list(head_mask)[0]}." - ), - ) - - for idx, encoder_layer in enumerate(self.encoder_layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - continue - - hidden_states, attn = encoder_layer( - hidden_states, - None, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - training=training, - ) - - if output_attentions: - all_attentions += (attn,) - - hidden_states = self.layer_norm(hidden_states) - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "conv1", None) is not None: - with tf.name_scope(self.conv1.name): - self.conv1.build([None, None, self.num_mel_bins]) - if getattr(self, "conv2", None) is not None: - with tf.name_scope(self.conv2.name): - self.conv2.build([None, None, self.embed_dim]) - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "encoder_layers", None) is not None: - for layer in self.encoder_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -class TFWhisperDecoder(keras.layers.Layer): - config_class = WhisperConfig - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFWhisperDecoderLayer`] - - Args: - config: WhisperConfig - """ - - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.dropout = keras.layers.Dropout(config.dropout) - self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_target_positions - self.max_source_positions = config.max_source_positions - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = keras.layers.Embedding( - input_dim=config.vocab_size, - output_dim=config.d_model, - embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), - name="embed_tokens", - ) - self.embed_positions = TFWhisperPositionalEmbedding( - self.max_target_positions, config.d_model, name="embed_positions" - ) - - self.decoder_layers = [TFWhisperDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] - - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - batch_size, seq_len = input_shape[0], input_shape[1] - - combined_attention_mask = tf.cond( - tf.math.greater(seq_len, 1), - lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length), - lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len), - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - return combined_attention_mask - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - encoder_hidden_states=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the - range `[0, config.max_position_embeddings - 1]`. - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = tf.shape(input_ids) - input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) - elif inputs_embeds is not None: - input_shape = tf.shape(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - # past_key_values_length - past_key_values_length = tf.shape(past_key_values[0][0])[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) - inputs_embeds = self.embed_tokens(input_ids) - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length) - - # embed positions - filled_past_positions = past_key_values_length if position_ids is None else position_ids[0, -1] - positions = self.embed_positions(input_ids, past_key_values_length=filled_past_positions) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.decoder_layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.decoder_layers)} layers, but it is" - f" for {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.decoder_layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - hidden_states = self.layer_norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "embed_positions", None) is not None: - with tf.name_scope(self.embed_positions.name): - self.embed_positions.build(None) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "decoder_layers", None) is not None: - for layer in self.decoder_layers: - with tf.name_scope(layer.name): - layer.build(None) - - -@add_start_docstrings( - "The bare Whisper Model outputting raw hidden-states without any specific head on top.", - WHISPER_START_DOCSTRING, -) -@keras_serializable -class TFWhisperMainLayer(keras.layers.Layer): - config_class = WhisperConfig - - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.encoder = TFWhisperEncoder(config, name="encoder") - self.decoder = TFWhisperDecoder(config, name="decoder") - - def get_input_embeddings(self): - return self.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.decoder.embed_tokens = value - - def get_encoder(self): - return self.encoder - - @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_features=None, - decoder_input_ids=None, - decoder_attention_mask=None, - decoder_position_ids=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs=None, - past_key_values=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ): - r""" - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import TFWhisperModel, AutoFeatureExtractor - >>> from datasets import load_dataset - - >>> model = TFWhisperModel.from_pretrained("openai/whisper-base") - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf") - >>> input_features = inputs.input_features - >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id - >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state - >>> list(last_hidden_state.shape) - [1, 2, 512] - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_features, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "decoder", None) is not None: - with tf.name_scope(self.decoder.name): - self.decoder.build(None) - - -@add_start_docstrings( - "The bare Whisper Model outputting raw hidden-states without any specific head on top.", - WHISPER_START_DOCSTRING, -) -class TFWhisperModel(TFWhisperPreTrainedModel): - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(config, **kwargs) - - self.model = TFWhisperMainLayer(config, name="model") - - def get_input_embeddings(self): - return self.model.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.decoder.embed_tokens = value - - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - - def decoder(self): - return self.model.decoder - - def encoder(self): - return self.model.encoder - - @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_features: TFModelInputType | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - decoder_inputs_embeds: tuple[np.ndarray | tf.Tensor] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFSeq2SeqModelOutput: - r""" - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import TFWhisperModel, AutoFeatureExtractor - >>> from datasets import load_dataset - - >>> model = TFWhisperModel.from_pretrained("openai/whisper-base") - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf") - >>> input_features = inputs.input_features - >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id - >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state - >>> list(last_hidden_state.shape) - [1, 2, 512] - ```""" - outputs = self.model( - input_features=input_features, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -@add_start_docstrings( - "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", - WHISPER_START_DOCSTRING, -) -class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss): - base_model_prefix = "model" - _keys_to_ignore_on_load_missing = [ - r"encoder.version", - r"decoder.version", - r"proj_out.weight", - ] - _keys_to_ignore_on_save = [ - r"proj_out.weight", - ] - - def __init__(self, config: WhisperConfig, **kwargs): - super().__init__(config, **kwargs) - self.model = TFWhisperMainLayer(config, name="model") - - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - def resize_token_embeddings(self, new_num_tokens: int) -> keras.layers.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - - @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - @unpack_inputs - def call( - self, - input_features: TFModelInputType | None = None, - decoder_input_ids: np.ndarray | tf.Tensor | None = None, - decoder_attention_mask: np.ndarray | tf.Tensor | None = None, - decoder_position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - decoder_head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - encoder_outputs: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - decoder_inputs_embeds: tuple[np.ndarray | tf.Tensor] | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> tuple[tf.Tensor] | TFSeq2SeqLMOutput: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` - or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is - only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoProcessor, TFWhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf") - >>> input_features = inputs.input_features - - >>> generated_ids = model.generate(input_features=input_features) - - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> transcription - ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - input_features, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - decoder_last_hidden_state = outputs[0] - # Decoder and encoder embeddings are tied - lm_logits = tf.matmul(decoder_last_hidden_state, self.get_output_embeddings().weights, transpose_b=True) - - loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSeq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - def generate( - self, - inputs: tf.Tensor | None = None, - generation_config: GenerationConfig | None = None, - logits_processor: TFLogitsProcessorList | None = None, - seed: list[int] | None = None, - return_timestamps: bool | None = None, - task: str | None = None, - language: str | None = None, - is_multilingual: bool | None = None, - prompt_ids: tf.Tensor | None = None, - return_token_timestamps=None, - **kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](../generation_strategies). - - - - Parameters: - inputs (`tf.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If unset the method - initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in - the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, - `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - seed (`list[int]`, *optional*): - Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the - `seed` argument from stateless functions in `tf.random`. - return_timestamps (`bool`, *optional*): - Whether to return the timestamps with the text. This enables the `TFWhisperTimestampsLogitsProcessor`. - task (`str`, *optional*): - Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` - will be updated accordingly. - language (`str`, *optional*): - Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can - find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. - is_multilingual (`bool`, *optional*): - Whether or not the model is multilingual. - prompt_ids (`tf.Tensor`, *optional*): - Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is - provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for - transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words - correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. - return_token_timestamps (`bool`, *optional*): - Whether to return token-level timestamps with the text. This can be used with or without the - `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into - words. - kwargs (`dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when - `config.return_dict_in_generate=True`) or a `tf.Tensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.TFGreedySearchDecoderOnlyOutput`], - - [`~generation.TFSampleDecoderOnlyOutput`], - - [`~generation.TFBeamSearchDecoderOnlyOutput`], - - [`~generation.TFBeamSampleDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.TFGreedySearchEncoderDecoderOutput`], - - [`~generation.TFSampleEncoderDecoderOutput`], - - [`~generation.TFBeamSearchEncoderDecoderOutput`], - - [`~generation.TFBeamSampleEncoderDecoderOutput`] - - """ - if generation_config is None: - generation_config = self.generation_config - - if return_timestamps is not None: - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - ) - - generation_config.return_timestamps = return_timestamps - else: - generation_config.return_timestamps = False - - if language is not None: - language = language.lower() - generation_config.language = language - if task is not None: - generation_config.task = task - - forced_decoder_ids = None - - # Legacy code for backward compatibility - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: - forced_decoder_ids = self.config.forced_decoder_ids - elif ( - hasattr(self.generation_config, "forced_decoder_ids") - and self.generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = self.generation_config.forced_decoder_ids - else: - forced_decoder_ids = kwargs.get("forced_decoder_ids") - - if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): - forced_decoder_ids = [] - if hasattr(generation_config, "language"): - if generation_config.language in generation_config.lang_to_id: - language_token = generation_config.language - elif generation_config.language in TO_LANGUAGE_CODE: - language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" - elif generation_config.language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{generation_config.language}|>" - else: - is_language_code = len(generation_config.language) == 2 - raise ValueError( - f"Unsupported language: {generation_config.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." - ) - if language_token not in generation_config.lang_to_id: - raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" - ) - forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) - else: - forced_decoder_ids.append((1, None)) # automatically detect the language - - if hasattr(generation_config, "task"): - if generation_config.task in TASK_IDS: - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - raise ValueError( - f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" - ) - elif hasattr(generation_config, "task_to_id"): - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe - if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if forced_decoder_ids is not None: - generation_config.forced_decoder_ids = forced_decoder_ids - - if prompt_ids is not None: - if kwargs.get("decoder_start_token_id") is not None: - raise ValueError( - "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." - ) - prompt_ids = prompt_ids.tolist() - decoder_start_token_id, *text_prompt_ids = prompt_ids - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :] - # Set the decoder_start_token_id to <|startofprev|> - kwargs.update({"decoder_start_token_id": decoder_start_token_id}) - - # Update the max generation length to include the prompt - specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None) - default_max_length = generation_config.max_new_tokens or generation_config.max_length - non_prompt_max_length = specified_max_length or default_max_length - kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids) - - # Reformat the forced_decoder_ids to incorporate the prompt - non_prompt_forced_decoder_ids = ( - kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids - ) - forced_decoder_ids = [ - *text_prompt_ids, - generation_config.decoder_start_token_id, - *[token for _rank, token in non_prompt_forced_decoder_ids], - ] - forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] - generation_config.forced_decoder_ids = forced_decoder_ids - - # TODO: Implement `WhisperTimeStampLogitsProcessor`. - if generation_config.return_timestamps: - # logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)] - raise ValueError("`TFWhisperForConditionalGeneration` doesn't support returning the timestamps yet.") - - if return_token_timestamps: - kwargs["output_attentions"] = True - kwargs["return_dict_in_generate"] = True - - if getattr(generation_config, "task", None) == "translate": - logger.warning("Token-level timestamps may not be reliable for task 'translate'.") - if not hasattr(generation_config, "alignment_heads"): - raise ValueError( - "Model generation config has no `alignment_heads`, token-level timestamps not available. " - "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." - ) - - outputs = super().generate( - inputs, - generation_config, - logits_processor, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads) - - return outputs - - def serving_output(self, output): - pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - use_cache=None, - encoder_outputs=None, - attention_mask=None, - decoder_attention_mask=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - if decoder_attention_mask is not None: # xla - decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] - elif past_key_values is not None: # no xla + past - decoder_position_ids = past_key_values[0][0].shape[2] - else: # no xla + no past - decoder_position_ids = tf.range(decoder_input_ids.shape[1]) - decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape) - - return { - "input_features": None, # Needs to be passed to make Keras.layer.__call__ happy - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "use_cache": use_cache, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -__all__ = ["TFWhisperForConditionalGeneration", "TFWhisperModel", "TFWhisperPreTrainedModel"] diff --git a/src/transformers/models/xglm/modeling_flax_xglm.py b/src/transformers/models/xglm/modeling_flax_xglm.py deleted file mode 100644 index 1366148d9a3d..000000000000 --- a/src/transformers/models/xglm/modeling_flax_xglm.py +++ /dev/null @@ -1,803 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax XGLM model.""" - -import math -import random -from functools import partial -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_xglm import XGLMConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/xglm-564M" -_CONFIG_FOR_DOC = "XGLMConfig" - -XGLM_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`XGLMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - -XGLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -def create_sinusoidal_positions(n_pos, dim, padding_idx=1): - half_dim = dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = np.exp(np.arange(half_dim) * -emb) - emb = np.expand_dims(np.arange(n_pos), 1) * np.expand_dims(emb, 0) - emb = np.concatenate([np.sin(emb), np.cos(emb)], 1) - emb = np.reshape(emb, (n_pos, dim)) - - if padding_idx is not None: - emb[padding_idx, :] = 0 - - return jnp.array(emb) - - -class FlaxXGLMAttention(nn.Module): - config: XGLMConfig - embed_dim: int - num_heads: int - dropout: float = 0.0 - causal: bool = False - bias: bool = True - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self) -> None: - self.head_dim = self.embed_dim // self.num_heads - - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} " - f"and `num_heads`: {self.num_heads})." - ) - - dense = partial( - nn.Dense, - self.embed_dim, - use_bias=self.bias, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() - self.out_proj = dense() - - self.dropout_layer = nn.Dropout(rate=self.dropout) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend - # to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states: jnp.ndarray, - key_value_states: Optional[jnp.ndarray] = None, - attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states) - value_states = self.v_proj(key_value_states) - else: - # self_attention - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = self._merge_heads(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class FlaxXGLMDecoderLayer(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxXGLMAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.attention_heads, - dropout=self.config.attention_dropout, - causal=True, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - - if self.config.add_cross_attention: - self.encoder_attn = FlaxXGLMAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - self.fc1 = nn.Dense( - self.config.ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer.__call__ - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = True, - deterministic: bool = True, - ) -> tuple[jnp.ndarray]: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs - - -class FlaxXGLMDecoderLayerCollection(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxXGLMDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_layers) - ] - self.layerdrop = self.config.layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): - layer_outputs = (None, None, None) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_self_attns, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxXGLMModule(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - self.embed_tokens = nn.Embed( - self.config.vocab_size, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - # XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = create_sinusoidal_positions( - self.config.max_position_embeddings + self.offset, embed_dim - ) - self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype) - self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - # embed positions - position_ids = position_ids + self.offset - positions = jnp.take(self.embed_positions, position_ids, axis=0) - - hidden_states = inputs_embeds + positions - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_states = outputs[0] - last_hidden_states = self.layer_norm(last_hidden_states) - - hidden_states = None - if output_hidden_states: - hidden_states = outputs[1] - hidden_states = hidden_states[:-1] + (last_hidden_states,) - - if not return_dict: - outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=last_hidden_states, - hidden_states=hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel): - config_class = XGLMConfig - base_model_prefix: str = "model" - module_class: nn.Module = None - - def __init__( - self, - config: XGLMConfig, - input_shape: tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - attention_mask = jnp.ones_like(input_ids) - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: Optional[dict] = None, - past_key_values: Optional[dict] = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if encoder_hidden_states is not None and encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxXGLMAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - return outputs - - -@add_start_docstrings( - "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", - XGLM_START_DOCSTRING, -) -class FlaxXGLMModel(FlaxXGLMPreTrainedModel): - module_class = FlaxXGLMModule - - -append_call_sample_docstring( - FlaxXGLMModel, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPastAndCrossAttentions, - _CONFIG_FOR_DOC, -) - - -class FlaxXGLMForCausalLMModule(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.model = FlaxXGLMModule(self.config, self.dtype) - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids, - attention_mask, - position_ids, - encoder_hidden_states, - encoder_attention_mask, - deterministic=deterministic, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["embed_tokens"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - XGLM_START_DOCSTRING, -) -class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel): - module_class = FlaxXGLMForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since GPT2 uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxXGLMForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = ["FlaxXGLMForCausalLM", "FlaxXGLMModel", "FlaxXGLMPreTrainedModel"] diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py deleted file mode 100644 index d799ced79208..000000000000 --- a/src/transformers/models/xglm/modeling_tf_xglm.py +++ /dev/null @@ -1,1002 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 XGLM model.""" - -from __future__ import annotations - -import math -import random -from typing import Any - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation - -# Public API -from ...file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions, TFCausalLMOutputWithCrossAttentions -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSharedEmbeddings, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import logging -from .configuration_xglm import XGLMConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/xglm-564M" -_CONFIG_FOR_DOC = "XGLMConfig" - - -LARGE_NEGATIVE = -1e8 - - -def create_sinusoidal_positions(num_positions: int, embedding_dim: int, padding_idx: int | None) -> tf.Tensor: - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = tf.exp(tf.range(half_dim, dtype=tf.float32) * -emb) - emb = tf.expand_dims(tf.range(num_positions, dtype=tf.float32), axis=1) * tf.expand_dims(emb, axis=0) - emb = tf.reshape(tf.concat([tf.sin(emb), tf.cos(emb)], axis=1), (num_positions, -1)) - if embedding_dim % 2 == 1: - # zero pad - emb = tf.concat([emb, tf.zeros((num_positions, 1))], axis=1) - if padding_idx is not None: - _padding_mask = tf.concat( - [ - tf.ones((padding_idx, shape_list(emb)[1])), - tf.zeros((1, shape_list(emb)[1])), - tf.ones((shape_list(emb)[0] - padding_idx - 1, shape_list(emb)[1])), - ], - axis=0, - ) - emb *= _padding_mask - - return tf.constant(emb, name="embed_positions") - - -def _create_position_ids_from_input_ids( - input_ids: tf.Tensor, past_key_values_length: int, padding_idx: int | None -) -> tf.Tensor: - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = tf.where(input_ids != padding_idx, 1, 0) - incremental_indices = (tf.cast(tf.cumsum(mask, axis=1), dtype=mask.dtype) + past_key_values_length) * mask - return tf.cast(incremental_indices, dtype=tf.int64) + padding_idx - - -def _create_position_ids_from_inputs_embeds( - inputs_embeds: tf.Tensor, past_key_values_length: int, padding_idx: int | None -) -> tf.Tensor: - """ - Args: - We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. - inputs_embeds: tf.Tensor - Returns: tf.Tensor - """ - input_shape = shape_list(inputs_embeds)[:-1] - sequence_length = input_shape[1] - - position_ids = tf.range(padding_idx + 1, sequence_length + padding_idx + 1, dtype=tf.int64) - - return tf.broadcast_to(tf.expand_dims(position_ids, axis=0), input_shape) + past_key_values_length - - -# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz = input_ids_shape[0] - tgt_len = input_ids_shape[1] - mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE - mask_cond = tf.range(shape_list(mask)[-1]) - - mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - - if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) - - return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) - - -# Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - src_len = shape_list(mask)[1] - tgt_len = tgt_len if tgt_len is not None else src_len - one_cst = tf.constant(1.0) - mask = tf.cast(mask, dtype=one_cst.dtype) - expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - - return (one_cst - expanded_mask) * LARGE_NEGATIVE - - -# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->XGLM -class TFXGLMAttention(keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): - return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) - - def call( - self, - hidden_states: tf.Tensor, - key_value_states: tf.Tensor | None = None, - past_key_value: tuple[tuple[tf.Tensor]] | None = None, - attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = shape_list(hidden_states) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = tf.concat([past_key_value[0], key_states], axis=2) - value_states = tf.concat([past_key_value[1], value_states], axis=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) - key_states = tf.reshape(key_states, proj_shape) - value_states = tf.reshape(value_states, proj_shape) - - src_len = shape_list(key_states)[1] - attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {shape_list(attn_weights)}" - ), - ) - - if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {shape_list(attention_mask)}" - ), - ) - - attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = stable_softmax(attn_weights, axis=-1) - - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=( - f"Head mask for a single layer should be of size {(self.num_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - - attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( - attn_weights, (bsz, self.num_heads, tgt_len, src_len) - ) - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {shape_list(attn_output)}" - ), - ) - - attn_output = tf.transpose( - tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) - ) - attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) - - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - - return attn_output, attn_weights, past_key_value - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "k_proj", None) is not None: - with tf.name_scope(self.k_proj.name): - self.k_proj.build([None, None, self.embed_dim]) - if getattr(self, "q_proj", None) is not None: - with tf.name_scope(self.q_proj.name): - self.q_proj.build([None, None, self.embed_dim]) - if getattr(self, "v_proj", None) is not None: - with tf.name_scope(self.v_proj.name): - self.v_proj.build([None, None, self.embed_dim]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.embed_dim]) - - -class TFXGLMDecoderLayer(keras.layers.Layer): - def __init__(self, config: XGLMConfig, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFXGLMAttention( - embed_dim=self.embed_dim, - num_heads=config.attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - name="self_attn", - ) - self.dropout = keras.layers.Dropout(config.dropout) - self.activation_fn = get_tf_activation(config.activation_function) - self.activation_dropout = keras.layers.Dropout(config.activation_dropout) - - if config.add_cross_attention: - self.encoder_attn = TFXGLMAttention( - embed_dim=self.embed_dim, - num_heads=config.attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - name="encoder_attn", - ) - self.encoder_attn_layer_norm = keras.layers.LayerNormalization( - epsilon=1e-5, name="encoder_attn_layer_norm" - ) - - self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.fc1 = keras.layers.Dense(config.ffn_dim, name="fc1") - self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - self.config = config - - # Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer.call - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor | None = None, - encoder_hidden_states: tf.Tensor | None = None, - encoder_attention_mask: tf.Tensor | None = None, - layer_head_mask: tf.Tensor | None = None, - cross_attn_layer_head_mask: tf.Tensor | None = None, - past_key_value: tuple[tf.Tensor] | None = None, - training: bool | None = False, - ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: - """ - Args: - hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* - attention_mask (`tf.Tensor`): attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`tf.Tensor`): - cross attention input to the layer of shape *(batch, seq_len, embed_dim)* - encoder_attention_mask (`tf.Tensor`): encoder attention mask of size - *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size - *(decoder_attention_heads,)* - cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. - *(decoder_attention_heads,)* - past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - ) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout(hidden_states, training=training) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = residual + hidden_states - - return ( - hidden_states, - self_attn_weights, - cross_attn_weights, - present_key_value, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attn", None) is not None: - with tf.name_scope(self.self_attn.name): - self.self_attn.build(None) - if getattr(self, "self_attn_layer_norm", None) is not None: - with tf.name_scope(self.self_attn_layer_norm.name): - self.self_attn_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "fc1", None) is not None: - with tf.name_scope(self.fc1.name): - self.fc1.build([None, None, self.embed_dim]) - if getattr(self, "fc2", None) is not None: - with tf.name_scope(self.fc2.name): - self.fc2.build([None, None, self.config.ffn_dim]) - if getattr(self, "final_layer_norm", None) is not None: - with tf.name_scope(self.final_layer_norm.name): - self.final_layer_norm.build([None, None, self.embed_dim]) - if getattr(self, "encoder_attn", None) is not None: - with tf.name_scope(self.encoder_attn.name): - self.encoder_attn.build(None) - if getattr(self, "encoder_attn_layer_norm", None) is not None: - with tf.name_scope(self.encoder_attn_layer_norm.name): - self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) - - -@keras_serializable -class TFXGLMMainLayer(keras.layers.Layer): - config_class = XGLMConfig - - def __init__( - self, config: XGLMConfig, embed_tokens: TFSharedEmbeddings | None = None, *inputs, **kwargs: Any - ) -> None: - super().__init__(*inputs, **kwargs) - - self.config = config - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = TFSharedEmbeddings( - config.vocab_size, config.d_model, self.padding_idx, name="embed_tokens" - ) - - self.offset = 2 - self._embed_positions_weights = create_sinusoidal_positions( - num_positions=config.max_position_embeddings + self.offset, - embedding_dim=config.d_model, - padding_idx=config.pad_token_id, - ) - - self.dropout = keras.layers.Dropout(config.dropout) - self.layers = [TFXGLMDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_layers)] - self.layerdrop = config.layerdrop - self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") - - def get_input_embeddings(self) -> TFSharedEmbeddings: - return self.embed_tokens - - def set_input_embeddings(self, value: TFSharedEmbeddings) -> None: - self.embed_tokens = value - - def _prepare_decoder_attention_mask( - self, - attention_mask: tf.Tensor | None, - input_shape: tf.TensorShape, - past_key_values_length: int, - ) -> tf.Tensor: - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length) - combined_attention_mask = tf.cond( - input_shape[-1] > 1, lambda: combined_attention_mask, lambda: tf.ones_like(combined_attention_mask) - ) - if attention_mask is None: - return combined_attention_mask - expand_attention_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) - return expand_attention_mask + combined_attention_mask - - def embed_positions(self, position_ids: np.ndarray | tf.Tensor | None = None) -> tf.Tensor: - position_ids += self.offset - positions = tf.gather(self._embed_positions_weights, position_ids, axis=0) - return positions - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs: Any, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = tf.shape(input_ids) - input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) - elif inputs_embeds is not None: - input_shape = tf.shape(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if position_ids is None: - position_ids = tf.expand_dims( - tf.range(past_key_values_length, input_shape[-1] + past_key_values_length), axis=0 - ) - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) - - # embed positions - positions = self.embed_positions(position_ids) - - hidden_states = tf.cast(inputs_embeds, dtype=tf.float32) + positions - - hidden_states = self.dropout(hidden_states, training=training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired - for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: - if attn_mask is not None: - tf.debugging.assert_equal( - shape_list(attn_mask)[0], - len(self.layers), - message=( - f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" - f" {shape_list(attn_mask)[0]}." - ), - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, - ) - - if use_cache: - next_decoder_cache += (present_key_value,) - - if output_attentions: - all_self_attns += (layer_self_attn,) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_cross_attn,) - - hidden_states = self.layer_norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "embed_tokens", None) is not None: - with tf.name_scope(self.embed_tokens.name): - self.embed_tokens.build(None) - if getattr(self, "layers", None) is not None: - for layer in self.layers: - with tf.name_scope(layer.name): - layer.build(None) - - -class TFXGLMPreTrainedModel(TFPreTrainedModel): - config_class = XGLMConfig - base_model_prefix = "model" - - -XGLM_START_DOCSTRING = r""" - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Args: - config ([`XGLMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -XGLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of - the decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`tf.Tensor` of shape `(num_layers, attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`tf.Tensor` of shape `(num_layers, attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.num_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", - XGLM_START_DOCSTRING, -) -class TFXGLMModel(TFXGLMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`TFXGLMDecoderLayer`] - - Args: - config: XGLMConfig - embed_tokens: [TFSharedEmbeddings]: output embedding - """ - - def __init__( - self, config: XGLMConfig, embed_tokens: TFSharedEmbeddings | None = None, *inputs: Any, **kwargs: Any - ) -> None: - super().__init__(config, *inputs, **kwargs) - - self.model = TFXGLMMainLayer(config, embed_tokens=embed_tokens, name="model") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs: Any, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - - -@add_start_docstrings( - """ - The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - XGLM_START_DOCSTRING, -) -class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss): - base_model_prefix = "model" - _keys_to_ignore_on_load_missing = [ - r"model.embed_positions.weights", - r"lm_head.weight", - ] - _keys_to_ignore_on_save = [ - r"model.embed_positions.weights", - ] - - def __init__( - self, config: XGLMConfig, embed_tokens: TFSharedEmbeddings | None = None, *inputs: Any, **kwargs: Any - ) -> None: - super().__init__(config, *inputs, **kwargs) - - self.model = TFXGLMMainLayer(config, embed_tokens=embed_tokens, name="model") - self.lm_head = keras.layers.Dense( - config.vocab_size, - use_bias=False, - kernel_initializer=get_initializer(config.init_std), - name="lm_head", - ) - self.config = config - - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - inputs = tf.expand_dims(inputs[:, -1], -1) - - position_ids = kwargs.get("position_ids") - attention_mask = kwargs.get("attention_mask") - - if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) - if past_key_values: - position_ids = tf.expand_dims(position_ids[:, -1], -1) - - return { - "input_ids": inputs, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - } - - @unpack_inputs - @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - labels: np.ndarray | tf.Tensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - **kwargs: Any, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - labels = tf.concat( - [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(-100, labels.dtype))], - axis=-1, - ) - loss = self.hf_compute_loss(labels, lm_logits) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "model", None) is not None: - with tf.name_scope(self.model.name): - self.model.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build([None, None, self.config.hidden_size]) - - def tf_to_pt_weight_rename(self, tf_weight): - if tf_weight == "lm_head.weight": - return tf_weight, "model.embed_tokens.weight" - else: - return (tf_weight,) - - -__all__ = ["TFXGLMForCausalLM", "TFXGLMModel", "TFXGLMPreTrainedModel"] diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py deleted file mode 100644 index db89b4686f84..000000000000 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ /dev/null @@ -1,1356 +0,0 @@ -# coding=utf-8 -# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TF 2.0 XLM model. -""" - -from __future__ import annotations - -import itertools -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFSharedEmbeddings, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - MULTIPLE_CHOICE_DUMMY_INPUTS, - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_xlm import XLMConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-mlm-en-2048" -_CONFIG_FOR_DOC = "XLMConfig" - - -def create_sinusoidal_embeddings(n_pos, dim, out): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - out[:, 0::2] = tf.constant(np.sin(position_enc[:, 0::2])) - out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2])) - - -def get_masks(slen, lengths, causal, padding_mask=None): - """ - Generate hidden states mask, and optionally an attention mask. - """ - bs = shape_list(lengths)[0] - if padding_mask is not None: - mask = padding_mask - else: - # assert lengths.max().item() <= slen - alen = tf.range(slen, dtype=lengths.dtype) - mask = alen < tf.expand_dims(lengths, axis=1) - - # attention mask is the same as mask, or triangular inferior attention (causal) - if causal: - attn_mask = tf.less_equal( - tf.tile(tf.reshape(alen, (1, 1, slen)), (bs, slen, 1)), tf.reshape(alen, (1, slen, 1)) - ) - else: - attn_mask = mask - - # sanity check - # assert shape_list(mask) == [bs, slen] - tf.debugging.assert_equal(shape_list(mask), [bs, slen]) - if causal: - tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen]) - - return mask, attn_mask - - -class TFXLMMultiHeadAttention(keras.layers.Layer): - NEW_ID = itertools.count() - - def __init__(self, n_heads, dim, config, **kwargs): - super().__init__(**kwargs) - self.layer_id = next(TFXLMMultiHeadAttention.NEW_ID) - self.dim = dim - self.n_heads = n_heads - self.output_attentions = config.output_attentions - assert self.dim % self.n_heads == 0 - - self.q_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin") - self.k_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin") - self.v_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin") - self.out_lin = keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin") - self.dropout = keras.layers.Dropout(config.attention_dropout) - self.pruned_heads = set() - self.dim = dim - - def prune_heads(self, heads): - raise NotImplementedError - - def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False): - """ - Self-attention (if kv is None) or attention over source sentence (provided by kv). - """ - # Input is (bs, qlen, dim) - # Mask is (bs, klen) (non-causal) or (bs, klen, klen) - bs, qlen, dim = shape_list(input) - - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = shape_list(kv)[1] - - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - dim_per_head = self.dim // self.n_heads - mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3)) - - def unshape(x): - """compute context""" - return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) - - if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) - v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = cache[self.layer_id] - - cache[self.layer_id] = (k, v) - - f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype) - q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head) - k = tf.cast(k, dtype=q.dtype) - scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) - mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) - # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) - mask = tf.cast(mask, dtype=scores.dtype) - scores = scores - 1e30 * (1.0 - mask) - weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) - weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) - - # Mask heads if we want to - if head_mask is not None: - weights = weights * head_mask - - context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) - outputs = (self.out_lin(context),) - - if output_attentions: - outputs = outputs + (weights,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "q_lin", None) is not None: - with tf.name_scope(self.q_lin.name): - self.q_lin.build([None, None, self.dim]) - if getattr(self, "k_lin", None) is not None: - with tf.name_scope(self.k_lin.name): - self.k_lin.build([None, None, self.dim]) - if getattr(self, "v_lin", None) is not None: - with tf.name_scope(self.v_lin.name): - self.v_lin.build([None, None, self.dim]) - if getattr(self, "out_lin", None) is not None: - with tf.name_scope(self.out_lin.name): - self.out_lin.build([None, None, self.dim]) - - -class TFXLMTransformerFFN(keras.layers.Layer): - def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs): - super().__init__(**kwargs) - - self.lin1 = keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1") - self.lin2 = keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2") - self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu") - self.dropout = keras.layers.Dropout(config.dropout) - self.in_dim = in_dim - self.dim_hidden = dim_hidden - - def call(self, input, training=False): - x = self.lin1(input) - x = self.act(x) - x = self.lin2(x) - x = self.dropout(x, training=training) - - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "lin1", None) is not None: - with tf.name_scope(self.lin1.name): - self.lin1.build([None, None, self.in_dim]) - if getattr(self, "lin2", None) is not None: - with tf.name_scope(self.lin2.name): - self.lin2.build([None, None, self.dim_hidden]) - - -@keras_serializable -class TFXLMMainLayer(keras.layers.Layer): - config_class = XLMConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.return_dict = config.use_return_dict - - # encoder / decoder, output layer - self.is_encoder = config.is_encoder - self.is_decoder = not config.is_encoder - - if self.is_decoder: - raise NotImplementedError("Currently XLM can only be used as an encoder") - - # self.with_output = with_output - self.causal = config.causal - - # dictionary / languages - self.n_langs = config.n_langs - self.use_lang_emb = config.use_lang_emb - self.n_words = config.n_words - self.eos_index = config.eos_index - self.pad_index = config.pad_index - # self.dico = dico - # self.id2lang = config.id2lang - # self.lang2id = config.lang2id - # assert len(self.dico) == self.n_words - # assert len(self.id2lang) == len(self.lang2id) == self.n_langs - - # model parameters - self.dim = config.emb_dim # 512 by default - self.hidden_dim = self.dim * 4 # 2048 by default - self.n_heads = config.n_heads # 8 by default - self.n_layers = config.n_layers - self.max_position_embeddings = config.max_position_embeddings - self.embed_init_std = config.embed_init_std - if self.dim % self.n_heads != 0: - raise ValueError("transformer dim must be a multiple of n_heads") - - # embeddings - self.dropout = keras.layers.Dropout(config.dropout) - self.attention_dropout = keras.layers.Dropout(config.attention_dropout) - - if config.sinusoidal_embeddings: - raise NotImplementedError - # create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) - - self.embeddings = TFSharedEmbeddings( - self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings" - ) # padding_idx=self.pad_index) - self.layer_norm_emb = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb") - - # transformer layers - self.attentions = [] - self.layer_norm1 = [] - self.ffns = [] - self.layer_norm2 = [] - # if self.is_decoder: - # self.layer_norm15 = [] - # self.encoder_attn = [] - - for i in range(self.n_layers): - self.attentions.append( - TFXLMMultiHeadAttention(self.n_heads, self.dim, config=config, name=f"attentions_._{i}") - ) - self.layer_norm1.append( - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm1_._{i}") - ) - # if self.is_decoder: - # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps)) - # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) - self.ffns.append( - TFXLMTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name=f"ffns_._{i}") - ) - self.layer_norm2.append( - keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}") - ) - - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - - for layer, heads in pruned_heads: - if self.attentions[int(layer)].n_heads == config.n_heads: - self.prune_heads({int(layer): list(map(int, heads))}) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.dim], - initializer=get_initializer(self.embed_init_std), - ) - - if self.n_langs > 1 and self.use_lang_emb: - with tf.name_scope("lang_embeddings"): - self.lang_embeddings = self.add_weight( - name="embeddings", - shape=[self.n_langs, self.dim], - initializer=get_initializer(self.embed_init_std), - ) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - if getattr(self, "layer_norm_emb", None) is not None: - with tf.name_scope(self.layer_norm_emb.name): - self.layer_norm_emb.build([None, None, self.dim]) - for layer in self.attentions: - with tf.name_scope(layer.name): - layer.build(None) - for layer in self.layer_norm1: - with tf.name_scope(layer.name): - layer.build([None, None, self.dim]) - for layer in self.ffns: - with tf.name_scope(layer.name): - layer.build(None) - for layer in self.layer_norm2: - with tf.name_scope(layer.name): - layer.build([None, None, self.dim]) - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - langs=None, - token_type_ids=None, - position_ids=None, - lengths=None, - cache=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - # removed: src_enc=None, src_len=None - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - bs, slen = shape_list(input_ids) - elif inputs_embeds is not None: - bs, slen = shape_list(inputs_embeds)[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if lengths is None: - if input_ids is not None: - lengths = tf.reduce_sum( - tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=input_ids.dtype), axis=1 - ) - else: - lengths = tf.convert_to_tensor([slen] * bs) - # mask = input_ids != self.pad_index - - # check inputs - # assert shape_list(lengths)[0] == bs - ( - tf.debugging.assert_equal(shape_list(lengths)[0], bs), - f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched", - ) - # assert lengths.max().item() <= slen - # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 - # assert (src_enc is None) == (src_len is None) - # if src_enc is not None: - # assert self.is_decoder - # assert src_enc.size(0) == bs - - # generate masks - mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) - # if self.is_decoder and src_enc is not None: - # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] - - # position_ids - if position_ids is None: - position_ids = tf.expand_dims(tf.range(slen), axis=0) - position_ids = tf.tile(position_ids, (bs, 1)) - - # assert shape_list(position_ids) == [bs, slen] # (slen, bs) - ( - tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]), - f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched", - ) - # position_ids = position_ids.transpose(0, 1) - - # langs - if langs is not None: - # assert shape_list(langs) == [bs, slen] # (slen, bs) - ( - tf.debugging.assert_equal(shape_list(langs), [bs, slen]), - f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched", - ) - # langs = langs.transpose(0, 1) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.n_layers - - # do not recompute cached elements - if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] - input_ids = input_ids[:, -_slen:] - position_ids = position_ids[:, -_slen:] - if langs is not None: - langs = langs[:, -_slen:] - mask = mask[:, -_slen:] - attn_mask = attn_mask[:, -_slen:] - - # embeddings - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size) - inputs_embeds = self.embeddings(input_ids) - - tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids) - - if langs is not None and self.use_lang_emb and self.n_langs > 1: - tensor = tensor + tf.gather(self.lang_embeddings, langs) - if token_type_ids is not None: - tensor = tensor + self.embeddings(token_type_ids) - - tensor = self.layer_norm_emb(tensor) - tensor = self.dropout(tensor, training=training) - mask = tf.cast(mask, dtype=tensor.dtype) - tensor = tensor * tf.expand_dims(mask, axis=-1) - - # transformer layers - hidden_states = () if output_hidden_states else None - attentions = () if output_attentions else None - - for i in range(self.n_layers): - if output_hidden_states: - hidden_states = hidden_states + (tensor,) - - # self attention - attn_outputs = self.attentions[i]( - tensor, - attn_mask, - None, - cache, - head_mask[i], - output_attentions, - training=training, - ) - attn = attn_outputs[0] - - if output_attentions: - attentions = attentions + (attn_outputs[1],) - - attn = self.dropout(attn, training=training) - tensor = tensor + attn - tensor = self.layer_norm1[i](tensor) - - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - - # FFN - tensor = tensor + self.ffns[i](tensor) - tensor = self.layer_norm2[i](tensor) - tensor = tensor * tf.expand_dims(mask, axis=-1) - - # Add last hidden state - if output_hidden_states: - hidden_states = hidden_states + (tensor,) - - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 0 - # tensor = tensor.transpose(0, 1) - - if not return_dict: - return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) - - return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) - - -class TFXLMPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = XLMConfig - base_model_prefix = "transformer" - - @property - def dummy_inputs(self): - # Sometimes XLM has language embeddings so don't forget to build them as well if needed - inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32) - attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32) - if self.config.use_lang_emb and self.config.n_langs > 1: - return { - "input_ids": inputs_list, - "attention_mask": attns_list, - "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32), - } - else: - return {"input_ids": inputs_list, "attention_mask": attns_list} - - -# Remove when XLMWithLMHead computes loss like other LM models -@dataclass -class TFXLMWithLMHeadModelOutput(ModelOutput): - """ - Base class for [`TFXLMWithLMHeadModel`] outputs. - - Args: - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: tf.Tensor | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -XLM_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`XLMConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -XLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - langs (`tf.Tensor` or `Numpy array` of shape `({0})`, *optional*): - A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are - languages ids which can be obtained from the language names by using two conversion mappings provided in - the configuration of the model (only provided for multilingual models). More precisely, the *language name - to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the - *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string). - - See usage examples detailed in the [multilingual documentation](../multilingual). - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - lengths (`tf.Tensor` or `Numpy array` of shape `(batch_size,)`, *optional*): - Length of each sentence that can be used to avoid performing attention on padding token indices. You can - also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in - `[0, ..., input_ids.size(-1)]`. - cache (`dict[str, tf.Tensor]`, *optional*): - Dictionary string to `tf.Tensor` that contains precomputed hidden states (key and values in the attention - blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare XLM Model transformer outputting raw hidden-states without any specific head on top.", - XLM_START_DOCSTRING, -) -class TFXLMModel(TFXLMPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLMMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: tf.Tensor | None = None, - langs: tf.Tensor | None = None, - token_type_ids: tf.Tensor | None = None, - position_ids: tf.Tensor | None = None, - lengths: tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: tf.Tensor | None = None, - inputs_embeds: tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutput | tuple[tf.Tensor]: - outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -class TFXLMPredLayer(keras.layers.Layer): - """ - Prediction layer (cross_entropy or adaptive_softmax). - """ - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.asm = config.asm - self.n_words = config.n_words - self.pad_index = config.pad_index - - if config.asm is False: - self.input_embeddings = input_embeddings - else: - raise NotImplementedError - # self.proj = nn.AdaptiveLogSoftmaxWithLoss( - # in_features=dim, - # n_classes=config.n_words, - # cutoffs=config.asm_cutoffs, - # div_value=config.asm_div_value, - # head_bias=True, # default is False - # ) - - def build(self, input_shape): - # The output weights are the same as the input embeddings, but there is an output-only bias for each token. - self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias") - - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.input_embeddings(hidden_states, mode="linear") - hidden_states = hidden_states + self.bias - - return hidden_states - - -@add_start_docstrings( - """ - The XLM Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - XLM_START_DOCSTRING, -) -class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLMMainLayer(config, name="transformer") - self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj") - # XLM does not have past caching features - self.supports_xla_generation = False - - def get_lm_head(self): - return self.pred_layer - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.pred_layer.name - - def prepare_inputs_for_generation(self, inputs, **kwargs): - mask_token_id = self.config.mask_token_id - lang_id = self.config.lang_id - - effective_batch_size = inputs.shape[0] - mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id - inputs = tf.concat([inputs, mask_token], axis=1) - - if lang_id is not None: - langs = tf.ones_like(inputs) * lang_id - else: - langs = None - return {"input_ids": inputs, "langs": langs} - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLMWithLMHeadModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFXLMWithLMHeadModelOutput | tuple[tf.Tensor]: - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - output = transformer_outputs[0] - outputs = self.pred_layer(output) - - if not return_dict: - return (outputs,) + transformer_outputs[1:] - - return TFXLMWithLMHeadModelOutput( - logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "pred_layer", None) is not None: - with tf.name_scope(self.pred_layer.name): - self.pred_layer.build(None) - - -@add_start_docstrings( - """ - XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. - for GLUE tasks. - """, - XLM_START_DOCSTRING, -) -class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFXLMMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - - logits = self.sequence_summary(output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - - -@add_start_docstrings( - """ - XLM Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - XLM_START_DOCSTRING, -) -class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.transformer = TFXLMMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary") - self.logits_proj = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" - ) - self.config = config - - @property - def dummy_inputs(self): - """ - Dummy inputs to build the network. - - Returns: - tf.Tensor with dummy inputs - """ - # Sometimes XLM has language embeddings so don't forget to build them as well if needed - if self.config.use_lang_emb and self.config.n_langs > 1: - return { - "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - } - else: - return { - "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32), - } - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - - if lengths is not None: - logger.warning( - "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the " - "attention mask instead.", - ) - lengths = None - - transformer_outputs = self.transformer( - flat_input_ids, - flat_attention_mask, - flat_langs, - flat_token_type_ids, - flat_position_ids, - lengths, - cache, - head_mask, - flat_inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - logits = self.sequence_summary(output) - logits = self.logits_proj(logits) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "logits_proj", None) is not None: - with tf.name_scope(self.logits_proj.name): - self.logits_proj.build([None, None, self.config.num_labels]) - - -@add_start_docstrings( - """ - XLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - XLM_START_DOCSTRING, -) -class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFXLMMainLayer(config, name="transformer") - self.dropout = keras.layers.Dropout(config.dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer - on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - XLM_START_DOCSTRING, -) -class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLMMainLayer(config, name="transformer") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - langs: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - lengths: np.ndarray | tf.Tensor | None = None, - cache: dict[str, tf.Tensor] | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - langs=langs, - token_type_ids=token_type_ids, - position_ids=position_ids, - lengths=lengths, - cache=cache, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFXLMForMultipleChoice", - "TFXLMForQuestionAnsweringSimple", - "TFXLMForSequenceClassification", - "TFXLMForTokenClassification", - "TFXLMMainLayer", - "TFXLMModel", - "TFXLMPreTrainedModel", - "TFXLMWithLMHeadModel", -] diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py deleted file mode 100644 index bdbc06620a1b..000000000000 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ /dev/null @@ -1,1511 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Flax XLM-RoBERTa model.""" - -from typing import Callable, Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax import lax - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxBaseModelOutputWithPooling, - FlaxBaseModelOutputWithPoolingAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxMaskedLMOutput, - FlaxMultipleChoiceModelOutput, - FlaxQuestionAnsweringModelOutput, - FlaxSequenceClassifierOutput, - FlaxTokenClassifierOutput, -) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from .configuration_xlm_roberta import XLMRobertaConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base" -_CONFIG_FOR_DOC = "XLMRobertaConfig" - -remat = nn_partitioning.remat - - -# Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: jnp.ndarray - padding_idx: int - - Returns: jnp.ndarray - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = (input_ids != padding_idx).astype("i4") - - if mask.ndim > 2: - mask = mask.reshape((-1, mask.shape[-1])) - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - incremental_indices = incremental_indices.reshape(input_ids.shape) - else: - incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask - - return incremental_indices.astype("i4") + padding_idx - - -XLM_ROBERTA_START_DOCSTRING = r""" - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a - [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as - a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and - behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. -""" - -XLM_ROBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`numpy.ndarray` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`numpy.ndarray` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - head_mask (`numpy.ndarray` of shape `({0})`, `optional): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->XLMRoberta -class FlaxXLMRobertaEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.word_embeddings = nn.Embed( - self.config.vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.position_embeddings = nn.Embed( - self.config.max_position_embeddings, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.token_type_embeddings = nn.Embed( - self.config.type_vocab_size, - self.config.hidden_size, - embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - # Embed - inputs_embeds = self.word_embeddings(input_ids.astype("i4")) - position_embeds = self.position_embeddings(position_ids.astype("i4")) - token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) - - # Sum all embeddings - hidden_states = inputs_embeds + token_type_embeddings + position_embeds - - # Layer Norm - hidden_states = self.LayerNorm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->XLMRoberta -class FlaxXLMRobertaSelfAttention(nn.Module): - config: XLMRobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.head_dim = self.config.hidden_size // self.config.num_attention_heads - if self.config.hidden_size % self.config.num_attention_heads != 0: - raise ValueError( - "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " - " : {self.config.num_attention_heads}" - ) - - self.query = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.key = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.value = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - if self.causal: - self.causal_mask = make_causal_mask( - jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" - ) - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) - - @nn.compact - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slightly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = lax.dynamic_update_slice(cached_key.value, key, indices) - value = lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic=True, - output_attentions: bool = False, - ): - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size = hidden_states.shape[0] - - # get query proj - query_states = self.query(hidden_states) - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self.key(key_value_states) - value_states = self.value(key_value_states) - else: - # self_attention - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # handle cache prepare causal attention mask - if self.causal: - query_length, key_length = query_states.shape[1], key_states.shape[1] - if self.has_variable("cache", "cached_key"): - mask_shift = self.variables["cache"]["cache_index"] - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_mask = lax.dynamic_slice( - self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) - ) - else: - causal_mask = self.causal_mask[:, :, :query_length, :key_length] - causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) - - # combine masks if needed - if attention_mask is not None and self.causal: - attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) - attention_mask = combine_masks(attention_mask, causal_mask) - elif self.causal: - attention_mask = causal_mask - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # Convert the boolean attention mask to an attention bias. - if attention_mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), - ) - else: - attention_bias = None - - dropout_rng = None - if not deterministic and self.config.attention_probs_dropout_prob > 0.0: - dropout_rng = self.make_rng("dropout") - - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.config.attention_probs_dropout_prob, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - precision=None, - ) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) - - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) - - outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->XLMRoberta -class FlaxXLMRobertaSelfOutput(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - - def __call__(self, hidden_states, input_tensor, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->XLMRoberta -class FlaxXLMRobertaAttention(nn.Module): - config: XLMRobertaConfig - causal: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.self = FlaxXLMRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) - self.output = FlaxXLMRobertaSelfOutput(self.config, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - key_value_states=None, - init_cache=False, - deterministic=True, - output_attentions: bool = False, - ): - # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) - # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable - # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_outputs = self.self( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=key_value_states, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_outputs[1],) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->XLMRoberta -class FlaxXLMRobertaIntermediate(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.activation = ACT2FN[self.config.hidden_act] - - def __call__(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->XLMRoberta -class FlaxXLMRobertaOutput(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - - def __call__(self, hidden_states, attention_output, deterministic: bool = True): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.LayerNorm(hidden_states + attention_output) - return hidden_states - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->XLMRoberta -class FlaxXLMRobertaLayer(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.attention = FlaxXLMRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) - self.intermediate = FlaxXLMRobertaIntermediate(self.config, dtype=self.dtype) - self.output = FlaxXLMRobertaOutput(self.config, dtype=self.dtype) - if self.config.add_cross_attention: - self.crossattention = FlaxXLMRobertaAttention(self.config, causal=False, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - ): - # Self Attention - attention_outputs = self.attention( - hidden_states, - attention_mask, - layer_head_mask=layer_head_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = attention_outputs[0] - - # Cross-Attention Block - if encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - key_value_states=encoder_hidden_states, - deterministic=deterministic, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - - hidden_states = self.intermediate(attention_output) - hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[1],) - if encoder_hidden_states is not None: - outputs += (cross_attention_outputs[1],) - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->XLMRoberta -class FlaxXLMRobertaLayerCollection(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - if self.gradient_checkpointing: - FlaxXLMRobertaCheckpointLayer = remat(FlaxXLMRobertaLayer, static_argnums=(5, 6, 7)) - self.layers = [ - FlaxXLMRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - else: - self.layers = [ - FlaxXLMRobertaLayer(self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - - # Check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.shape[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for " - f" {head_mask.shape[0]}." - ) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = layer( - hidden_states, - attention_mask, - head_mask[i] if head_mask is not None else None, - encoder_hidden_states, - encoder_attention_mask, - init_cache, - deterministic, - output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->XLMRoberta -class FlaxXLMRobertaEncoder(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.layer = FlaxXLMRobertaLayerCollection( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - return self.layer( - hidden_states, - attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->XLMRoberta -class FlaxXLMRobertaPooler(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - cls_hidden_state = hidden_states[:, 0] - cls_hidden_state = self.dense(cls_hidden_state) - return nn.tanh(cls_hidden_state) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->XLMRoberta -class FlaxXLMRobertaLMHead(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) - self.decoder = nn.Dense( - self.config.vocab_size, - dtype=self.dtype, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - - def __call__(self, hidden_states, shared_embedding=None): - hidden_states = self.dense(hidden_states) - hidden_states = ACT2FN["gelu"](hidden_states) - hidden_states = self.layer_norm(hidden_states) - - if shared_embedding is not None: - hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - hidden_states = self.decoder(hidden_states) - - bias = jnp.asarray(self.bias, self.dtype) - hidden_states += bias - return hidden_states - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->XLMRoberta -class FlaxXLMRobertaClassificationHead(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.dense = nn.Dense( - self.config.hidden_size, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.out_proj = nn.Dense( - self.config.num_labels, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range), - ) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = nn.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with Roberta->XLMRoberta, roberta->xlm-roberta, ROBERTA->XLM_ROBERTA -class FlaxXLMRobertaPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = XLMRobertaConfig - base_model_prefix = "xlm-roberta" - - module_class: nn.Module = None - - def __init__( - self, - config: XLMRobertaConfig, - input_shape: tuple = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs, - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - token_type_ids = jnp.ones_like(input_ids) - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - attention_mask = jnp.ones_like(input_ids) - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - if self.config.add_cross_attention: - encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) - encoder_attention_mask = attention_mask - module_init_outputs = self.module.init( - rngs, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - return_dict=False, - ) - else: - module_init_outputs = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - ) - - random_params = module_init_outputs["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache - def init_cache(self, batch_size, max_length): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - """ - # init input variables to retrieve cache - input_ids = jnp.ones((batch_size, max_length), dtype="i4") - attention_mask = jnp.ones_like(input_ids, dtype="i4") - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - init_variables = self.module.init( - jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - params: Optional[dict] = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - past_key_values: Optional[dict] = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # init input tensors if not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - if position_ids is None: - position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - if head_mask is None: - head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - if self.config.add_cross_attention: - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed - # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be - # changed by FlaxXLMRobertaAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - mutable=mutable, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past_key_values = outputs - outputs["past_key_values"] = unfreeze(past_key_values["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past_key_values = outputs - outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] - - else: - outputs = self.module.apply( - inputs, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - token_type_ids=jnp.array(token_type_ids, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - head_mask=jnp.array(head_mask, dtype="i4"), - deterministic=not train, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - rngs=rngs, - ) - - return outputs - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->XLMRoberta -class FlaxXLMRobertaModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - add_pooling_layer: bool = True - gradient_checkpointing: bool = False - - def setup(self): - self.embeddings = FlaxXLMRobertaEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxXLMRobertaEncoder( - self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.pooler = FlaxXLMRobertaPooler(self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # make sure `token_type_ids` is correctly initialized when not passed - if token_type_ids is None: - token_type_ids = jnp.zeros_like(input_ids) - - # make sure `position_ids` is correctly initialized when not passed - if position_ids is None: - position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) - - hidden_states = self.embeddings( - input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic - ) - outputs = self.encoder( - hidden_states, - attention_mask, - head_mask=head_mask, - deterministic=deterministic, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - - if not return_dict: - # if pooled is None, don't return it - if pooled is None: - return (hidden_states,) + outputs[1:] - return (hidden_states, pooled) + outputs[1:] - - return FlaxBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=hidden_states, - pooler_output=pooled, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The bare XLM RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaModel(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaModule - - -append_call_sample_docstring(FlaxXLMRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->XLMRoberta -class FlaxXLMRobertaForMaskedLMModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxMaskedLMOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings("""XLM RoBERTa Model with a `language modeling` head on top.""", XLM_ROBERTA_START_DOCSTRING) -class FlaxXLMRobertaForMaskedLM(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForMaskedLMModule - - -append_call_sample_docstring( - FlaxXLMRobertaForMaskedLM, - _CHECKPOINT_FOR_DOC, - FlaxBaseModelOutputWithPooling, - _CONFIG_FOR_DOC, - mask="", -) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->XLMRoberta -class FlaxXLMRobertaForSequenceClassificationModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.classifier = FlaxXLMRobertaClassificationHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - logits = self.classifier(sequence_output, deterministic=deterministic) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxSequenceClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - XLM Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaForSequenceClassification(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForSequenceClassificationModule - - -append_call_sample_docstring( - FlaxXLMRobertaForSequenceClassification, - _CHECKPOINT_FOR_DOC, - FlaxSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->XLMRoberta, with self.bert->self.roberta -class FlaxXLMRobertaForMultipleChoiceModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.classifier = nn.Dense(1, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - num_choices = input_ids.shape[1] - input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None - token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None - position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None - - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, deterministic=deterministic) - logits = self.classifier(pooled_output) - - reshaped_logits = logits.reshape(-1, num_choices) - - if not return_dict: - return (reshaped_logits,) + outputs[2:] - - return FlaxMultipleChoiceModelOutput( - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - XLM Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaForMultipleChoice(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForMultipleChoiceModule - - -overwrite_call_docstring( - FlaxXLMRobertaForMultipleChoice, XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") -) -append_call_sample_docstring( - FlaxXLMRobertaForMultipleChoice, - _CHECKPOINT_FOR_DOC, - FlaxMultipleChoiceModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->XLMRoberta, with self.bert->self.roberta -class FlaxXLMRobertaForTokenClassificationModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - classifier_dropout = ( - self.config.classifier_dropout - if self.config.classifier_dropout is not None - else self.config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(rate=classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.classifier(hidden_states) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxTokenClassifierOutput( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - XLM Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaForTokenClassification(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForTokenClassificationModule - - -append_call_sample_docstring( - FlaxXLMRobertaForTokenClassification, - _CHECKPOINT_FOR_DOC, - FlaxTokenClassifierOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->XLMRoberta, with self.bert->self.roberta -class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - dtype=self.dtype, - add_pooling_layer=False, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - return (start_logits, end_logits) + outputs[1:] - - return FlaxQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - XLM Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaForQuestionAnswering(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForQuestionAnsweringModule - - -append_call_sample_docstring( - FlaxXLMRobertaForQuestionAnswering, - _CHECKPOINT_FOR_DOC, - FlaxQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) - - -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->XLMRoberta -class FlaxXLMRobertaForCausalLMModule(nn.Module): - config: XLMRobertaConfig - dtype: jnp.dtype = jnp.float32 - gradient_checkpointing: bool = False - - def setup(self): - self.roberta = FlaxXLMRobertaModule( - config=self.config, - add_pooling_layer=False, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - token_type_ids: Optional[jnp.ndarray] = None, - head_mask: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_mask: Optional[jnp.ndarray] = None, - init_cache: bool = False, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - # Model - outputs = self.roberta( - input_ids, - attention_mask, - token_type_ids, - position_ids, - head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.tie_word_embeddings: - shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] - else: - shared_embedding = None - - # Compute the prediction scores - logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) - - if not return_dict: - return (logits,) + outputs[1:] - - return FlaxCausalLMOutputWithCrossAttentions( - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -@add_start_docstrings( - """ - XLM Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for - autoregressive tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->XLMRoberta -class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel): - module_class = FlaxXLMRobertaForCausalLMModule - - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): - # initializing the cache - batch_size, seq_length = input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyway. - # Thus, we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) - - return { - "past_key_values": past_key_values, - "attention_mask": extended_attention_mask, - "position_ids": position_ids, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - return model_kwargs - - -append_call_sample_docstring( - FlaxXLMRobertaForCausalLM, - _CHECKPOINT_FOR_DOC, - FlaxCausalLMOutputWithCrossAttentions, - _CONFIG_FOR_DOC, -) - - -__all__ = [ - "FlaxXLMRobertaForMaskedLM", - "FlaxXLMRobertaForCausalLM", - "FlaxXLMRobertaForMultipleChoice", - "FlaxXLMRobertaForQuestionAnswering", - "FlaxXLMRobertaForSequenceClassification", - "FlaxXLMRobertaForTokenClassification", - "FlaxXLMRobertaModel", - "FlaxXLMRobertaPreTrainedModel", -] diff --git a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py deleted file mode 100644 index 0def1bfdb00d..000000000000 --- a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py +++ /dev/null @@ -1,1790 +0,0 @@ -# coding=utf-8 -# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 XLM-RoBERTa model.""" - -from __future__ import annotations - -import math -import warnings - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutputWithPastAndCrossAttentions, - TFBaseModelOutputWithPoolingAndCrossAttentions, - TFCausalLMOutputWithCrossAttentions, - TFMaskedLMOutput, - TFMultipleChoiceModelOutput, - TFQuestionAnsweringModelOutput, - TFSequenceClassifierOutput, - TFTokenClassifierOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFMaskedLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from .configuration_xlm_roberta import XLMRobertaConfig - - -logger = logging.get_logger(__name__) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base" -_CONFIG_FOR_DOC = "XLMRobertaConfig" - - -XLM_ROBERTA_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -XLM_ROBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See - [`PreTrainedTokenizer.__call__`] and [`PreTrainedTokenizer.encode`] for details. [What are input - IDs?](../glossary#input-ids) - attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) - head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings with Roberta->XLMRoberta -class TFXLMRobertaEmbeddings(keras.layers.Layer): - """ - Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. - """ - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.padding_idx = 1 - self.config = config - self.hidden_size = config.hidden_size - self.max_position_embeddings = config.max_position_embeddings - self.initializer_range = config.initializer_range - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - - def build(self, input_shape=None): - with tf.name_scope("word_embeddings"): - self.weight = self.add_weight( - name="weight", - shape=[self.config.vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("token_type_embeddings"): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.config.type_vocab_size, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - with tf.name_scope("position_embeddings"): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding - symbols are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - input_ids: tf.Tensor - Returns: tf.Tensor - """ - mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) - incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask - - return incremental_indices + self.padding_idx - - def call( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - inputs_embeds=None, - past_key_values_length=0, - training=False, - ): - """ - Applies embedding based on inputs tensor. - - Returns: - final_embeddings (`tf.Tensor`): output embedding tensor. - """ - assert not (input_ids is None and inputs_embeds is None) - - if input_ids is not None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = tf.gather(params=self.weight, indices=input_ids) - - input_shape = shape_list(inputs_embeds)[:-1] - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - if position_ids is None: - if input_ids is not None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids( - input_ids=input_ids, past_key_values_length=past_key_values_length - ) - else: - position_ids = tf.expand_dims( - tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 - ) - - position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) - token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) - final_embeddings = inputs_embeds + position_embeds + token_type_embeds - final_embeddings = self.LayerNorm(inputs=final_embeddings) - final_embeddings = self.dropout(inputs=final_embeddings, training=training) - - return final_embeddings - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->XLMRoberta -class TFXLMRobertaPooler(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(inputs=first_token_tensor) - - return pooled_output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->XLMRoberta -class TFXLMRobertaSelfAttention(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number " - f"of attention heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.sqrt_att_head_size = math.sqrt(self.attention_head_size) - - self.query = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" - ) - self.key = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" - ) - self.value = keras.layers.Dense( - units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" - ) - self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) - - self.is_decoder = config.is_decoder - self.config = config - - def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: - # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] - tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) - - # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] - return tf.transpose(tensor, perm=[0, 2, 1, 3]) - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - batch_size = shape_list(hidden_states)[0] - mixed_query_layer = self.query(inputs=hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - key_layer = tf.concat([past_key_value[0], key_layer], axis=2) - value_layer = tf.concat([past_key_value[1], value_layer], axis=2) - else: - key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) - value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) - - query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) - - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - # (batch size, num_heads, seq_len_q, seq_len_k) - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) - attention_scores = tf.divide(attention_scores, dk) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFXLMRobertaModel call() function) - attention_scores = tf.add(attention_scores, attention_mask) - - # Normalize the attention scores to probabilities. - attention_probs = stable_softmax(logits=attention_scores, axis=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(inputs=attention_probs, training=training) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = tf.multiply(attention_probs, head_mask) - - attention_output = tf.matmul(attention_probs, value_layer) - attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) - - # (batch_size, seq_len_q, all_head_size) - attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) - outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "query", None) is not None: - with tf.name_scope(self.query.name): - self.query.build([None, None, self.config.hidden_size]) - if getattr(self, "key", None) is not None: - with tf.name_scope(self.key.name): - self.key.build([None, None, self.config.hidden_size]) - if getattr(self, "value", None) is not None: - with tf.name_scope(self.value.name): - self.value.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->XLMRoberta -class TFXLMRobertaSelfOutput(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->XLMRoberta -class TFXLMRobertaAttention(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.self_attention = TFXLMRobertaSelfAttention(config, name="self") - self.dense_output = TFXLMRobertaSelfOutput(config, name="output") - - def prune_heads(self, heads): - raise NotImplementedError - - def call( - self, - input_tensor: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor, - encoder_attention_mask: tf.Tensor, - past_key_value: tuple[tf.Tensor], - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - self_outputs = self.self_attention( - hidden_states=input_tensor, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self.dense_output( - hidden_states=self_outputs[0], input_tensor=input_tensor, training=training - ) - # add attentions (possibly with past_key_value) if we output them - outputs = (attention_output,) + self_outputs[1:] - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "self_attention", None) is not None: - with tf.name_scope(self.self_attention.name): - self.self_attention.build(None) - if getattr(self, "dense_output", None) is not None: - with tf.name_scope(self.dense_output.name): - self.dense_output.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->XLMRoberta -class TFXLMRobertaIntermediate(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = get_tf_activation(config.hidden_act) - else: - self.intermediate_act_fn = config.hidden_act - self.config = config - - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->XLMRoberta -class TFXLMRobertaOutput(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.dense = keras.layers.Dense( - units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") - self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) - self.config = config - - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: - hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.dropout(inputs=hidden_states, training=training) - hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) - - return hidden_states - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.intermediate_size]) - if getattr(self, "LayerNorm", None) is not None: - with tf.name_scope(self.LayerNorm.name): - self.LayerNorm.build([None, None, self.config.hidden_size]) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->XLMRoberta -class TFXLMRobertaLayer(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - - self.attention = TFXLMRobertaAttention(config, name="attention") - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TFXLMRobertaAttention(config, name="crossattention") - self.intermediate = TFXLMRobertaIntermediate(config, name="intermediate") - self.bert_output = TFXLMRobertaOutput(config, name="output") - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_value: tuple[tf.Tensor] | None, - output_attentions: bool, - training: bool = False, - ) -> tuple[tf.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - input_tensor=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=self_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - input_tensor=attention_output, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - training=training, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - intermediate_output = self.intermediate(hidden_states=attention_output) - layer_output = self.bert_output( - hidden_states=intermediate_output, input_tensor=attention_output, training=training - ) - outputs = (layer_output,) + outputs # add attentions if we output them - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attention", None) is not None: - with tf.name_scope(self.attention.name): - self.attention.build(None) - if getattr(self, "intermediate", None) is not None: - with tf.name_scope(self.intermediate.name): - self.intermediate.build(None) - if getattr(self, "bert_output", None) is not None: - with tf.name_scope(self.bert_output.name): - self.bert_output.build(None) - if getattr(self, "crossattention", None) is not None: - with tf.name_scope(self.crossattention.name): - self.crossattention.build(None) - - -# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->XLMRoberta -class TFXLMRobertaEncoder(keras.layers.Layer): - def __init__(self, config: XLMRobertaConfig, **kwargs): - super().__init__(**kwargs) - self.config = config - self.layer = [TFXLMRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] - - def call( - self, - hidden_states: tf.Tensor, - attention_mask: tf.Tensor, - head_mask: tf.Tensor, - encoder_hidden_states: tf.Tensor | None, - encoder_attention_mask: tf.Tensor | None, - past_key_values: tuple[tuple[tf.Tensor]] | None, - use_cache: bool | None, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - training: bool = False, - ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_value = past_key_values[i] if past_key_values is not None else None - - layer_outputs = layer_module( - hidden_states=hidden_states, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - training=training, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if self.config.add_cross_attention and encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None - ) - - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - -@keras_serializable -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->XLMRoberta -class TFXLMRobertaMainLayer(keras.layers.Layer): - config_class = XLMRobertaConfig - - def __init__(self, config, add_pooling_layer=True, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.is_decoder = config.is_decoder - - self.num_hidden_layers = config.num_hidden_layers - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.return_dict = config.use_return_dict - self.encoder = TFXLMRobertaEncoder(config, name="encoder") - self.pooler = TFXLMRobertaPooler(config, name="pooler") if add_pooling_layer else None - # The embeddings must be the last declaration in order to follow the weights order - self.embeddings = TFXLMRobertaEmbeddings(config, name="embeddings") - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings - def get_input_embeddings(self) -> keras.layers.Layer: - return self.embeddings - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings - def set_input_embeddings(self, value: tf.Variable): - self.embeddings.weight = value - self.embeddings.vocab_size = shape_list(value)[0] - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - raise NotImplementedError - - @unpack_inputs - # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: - if not self.config.is_decoder: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_key_values_length = 0 - past_key_values = [None] * len(self.encoder.layer) - else: - past_key_values_length = shape_list(past_key_values[0][0])[-2] - - if attention_mask is None: - attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - - if token_type_ids is None: - token_type_ids = tf.fill(dims=input_shape, value=0) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - training=training, - ) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(attention_mask) - - mask_seq_length = seq_length + past_key_values_length - # Copied from `modeling_tf_t5.py` - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask * attention_mask[:, None, :] - attention_mask_shape = shape_list(extended_attention_mask) - extended_attention_mask = tf.reshape( - extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) - ) - if past_key_values[0] is not None: - # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = tf.reshape( - attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) - one_cst = tf.constant(1.0, dtype=embedding_output.dtype) - ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) - extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) - - # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.config.num_hidden_layers - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None - - if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] - - return TFBaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "encoder", None) is not None: - with tf.name_scope(self.encoder.name): - self.encoder.build(None) - if getattr(self, "pooler", None) is not None: - with tf.name_scope(self.pooler.name): - self.pooler.build(None) - if getattr(self, "embeddings", None) is not None: - with tf.name_scope(self.embeddings.name): - self.embeddings.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->XLMRoberta -class TFXLMRobertaPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = XLMRobertaConfig - base_model_prefix = "roberta" - - -@add_start_docstrings( - "The bare XLM RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaModel(TFXLMRobertaPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.roberta = TFXLMRobertaMainLayer(config, name="roberta") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool | None = False, - ) -> tuple | TFBaseModelOutputWithPoolingAndCrossAttentions: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->XLMRoberta -class TFXLMRobertaLMHead(keras.layers.Layer): - """XLMRoberta Head for masked language modeling.""" - - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.hidden_size = config.hidden_size - self.dense = keras.layers.Dense( - config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" - ) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.act = get_tf_activation("gelu") - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = input_embeddings - - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.hidden_size]) - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, value): - self.decoder.weight = value - self.decoder.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.layer_norm(hidden_states) - - # project back to size of vocabulary with bias - seq_length = shape_list(tensor=hidden_states)[1] - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) - hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) - hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) - hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) - - return hidden_states - - -@add_start_docstrings("""XLM RoBERTa Model with a `language modeling` head on top.""", XLM_ROBERTA_START_DOCSTRING) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForMaskedLM(TFXLMRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFXLMRobertaLMHead(config, self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - mask="", - expected_output="' Paris'", - expected_loss=0.1, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMaskedLMOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -@add_start_docstrings( - "XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.", - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForCausalLM(TFXLMRobertaPreTrainedModel, TFCausalLanguageModelingLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - - def __init__(self, config: XLMRobertaConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - if not config.is_decoder: - logger.warning("If you want to use `TFXLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") - - self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.lm_head = TFXLMRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") - - def get_lm_head(self): - return self.lm_head - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name - - # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = tf.ones(input_shape) - - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - encoder_hidden_states: np.ndarray | tf.Tensor | None = None, - encoder_attention_mask: np.ndarray | tf.Tensor | None = None, - past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: - r""" - encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). Set to `False` during training, `True` during generation - labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = outputs[0] - logits = self.lm_head(hidden_states=sequence_output, training=training) - loss = None - - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "lm_head", None) is not None: - with tf.name_scope(self.lm_head.name): - self.lm_head.build(None) - - -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->XLMRoberta -class TFXLMRobertaClassificationHead(keras.layers.Layer): - """Head for sentence-level classification tasks.""" - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.dense = keras.layers.Dense( - config.hidden_size, - kernel_initializer=get_initializer(config.initializer_range), - activation="tanh", - name="dense", - ) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.out_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" - ) - self.config = config - - def call(self, features, training=False): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x, training=training) - x = self.dense(x) - x = self.dropout(x, training=training) - x = self.out_proj(x) - return x - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "dense", None) is not None: - with tf.name_scope(self.dense.name): - self.dense.build([None, None, self.config.hidden_size]) - if getattr(self, "out_proj", None) is not None: - with tf.name_scope(self.out_proj.name): - self.out_proj.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - XLM RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForSequenceClassification(TFXLMRobertaPreTrainedModel, TFSequenceClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.classifier = TFXLMRobertaClassificationHead(config, name="classifier") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="cardiffnlp/twitter-roberta-base-emotion", - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="'optimism'", - expected_loss=0.08, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - logits = self.classifier(sequence_output, training=training) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build(None) - - -@add_start_docstrings( - """ - XLM Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForMultipleChoice(TFXLMRobertaPreTrainedModel, TFMultipleChoiceLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.roberta = TFXLMRobertaMainLayer(config, name="roberta") - self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) - self.classifier = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward( - XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") - ) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - outputs = self.roberta( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output, training=training) - logits = self.classifier(pooled_output) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFMultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - XLM RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForTokenClassification(TFXLMRobertaPreTrainedModel, TFTokenClassificationLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - _keys_to_ignore_on_load_missing = [r"dropout"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = keras.layers.Dropout(classifier_dropout) - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-large-ner-english", - output_type=TFTokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", - expected_loss=0.01, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output, training=training) - logits = self.classifier(sequence_output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFTokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - XLM RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - XLM_ROBERTA_START_DOCSTRING, -) -# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class TFXLMRobertaForQuestionAnswering(TFXLMRobertaPreTrainedModel, TFQuestionAnsweringLoss): - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint="ydshieh/roberta-base-squad2", - output_type=TFQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - expected_output="' puppet'", - expected_loss=0.86, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool | None = False, - ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - outputs = self.roberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TFQuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "roberta", None) is not None: - with tf.name_scope(self.roberta.name): - self.roberta.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFXLMRobertaForCausalLM", - "TFXLMRobertaForMaskedLM", - "TFXLMRobertaForMultipleChoice", - "TFXLMRobertaForQuestionAnswering", - "TFXLMRobertaForSequenceClassification", - "TFXLMRobertaForTokenClassification", - "TFXLMRobertaModel", - "TFXLMRobertaPreTrainedModel", -] diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index a15c5f22ad68..000000000000 --- a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,113 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert BERT checkpoint.""" - -import argparse -import os - -import torch - -from transformers import ( - XLNetConfig, - XLNetForQuestionAnswering, - XLNetForSequenceClassification, - XLNetLMHeadModel, - load_tf_weights_in_xlnet, -) -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -GLUE_TASKS_NUM_LABELS = { - "cola": 2, - "mnli": 3, - "mrpc": 2, - "sst-2": 2, - "sts-b": 1, - "qqp": 2, - "qnli": 2, - "rte": 2, - "wnli": 2, -} - - -logging.set_verbosity_info() - - -def convert_xlnet_checkpoint_to_pytorch( - tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None -): - # Initialise PyTorch model - config = XLNetConfig.from_json_file(bert_config_file) - - finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" - if finetuning_task in GLUE_TASKS_NUM_LABELS: - print(f"Building PyTorch XLNetForSequenceClassification model from configuration: {config}") - config.finetuning_task = finetuning_task - config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] - model = XLNetForSequenceClassification(config) - elif "squad" in finetuning_task: - config.finetuning_task = finetuning_task - model = XLNetForQuestionAnswering(config) - else: - model = XLNetLMHeadModel(config) - - # Load weights from tf checkpoint - load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) - - # Save pytorch-model - pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) - pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) - print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--xlnet_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained XLNet model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_folder_path", - default=None, - type=str, - required=True, - help="Path to the folder to store the PyTorch model or dataset/vocab.", - ) - parser.add_argument( - "--finetuning_task", - default=None, - type=str, - help="Name of a task on which the XLNet TensorFlow model was fine-tuned", - ) - args = parser.parse_args() - print(args) - - convert_xlnet_checkpoint_to_pytorch( - args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task - ) diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py deleted file mode 100644 index 451d26c844d8..000000000000 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ /dev/null @@ -1,1820 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TF 2.0 XLNet model. -""" - -from __future__ import annotations - -import warnings -from dataclasses import dataclass - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFMultipleChoiceLoss, - TFPreTrainedModel, - TFQuestionAnsweringLoss, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFSharedEmbeddings, - TFTokenClassificationLoss, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_xlnet import XLNetConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "xlnet/xlnet-base-cased" -_CONFIG_FOR_DOC = "XLNetConfig" - - -class TFXLNetRelativeAttention(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - if config.d_model % config.n_head != 0: - raise ValueError( - f"The hidden size ({config.d_model}) is not a multiple of the number of attention " - f"heads ({config.n_head}" - ) - - self.n_head = config.n_head - self.d_head = config.d_head - self.d_model = config.d_model - self.scale = 1 / (config.d_head**0.5) - self.initializer_range = config.initializer_range - self.output_attentions = config.output_attentions - - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.dropout = keras.layers.Dropout(config.dropout) - self.config = config - - def build(self, input_shape=None): - initializer = get_initializer(self.initializer_range) - self.q = self.add_weight( - shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q" - ) - self.k = self.add_weight( - shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k" - ) - self.v = self.add_weight( - shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v" - ) - self.o = self.add_weight( - shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o" - ) - self.r = self.add_weight( - shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r" - ) - self.r_r_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias" - ) - self.r_s_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias" - ) - self.r_w_bias = self.add_weight( - shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias" - ) - self.seg_embed = self.add_weight( - shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed" - ) - - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - - def prune_heads(self, heads): - raise NotImplementedError - - def rel_shift(self, x, klen=-1): - """perform relative shift to form the relative attention score.""" - x_size = shape_list(x) - - x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3])) - x = x[1:, ...] - x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3])) - x = x[:, 0:klen, :, :] - # x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long)) - - return x - - def rel_attn_core( - self, q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions, training=False - ): - """Core relative positional attention operations.""" - # content based attention score - ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h) - - # position based attention score - bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r) - bd = self.rel_shift(bd, klen=shape_list(ac)[1]) - - # segment based attention score - if seg_mat is None: - ef = 0 - else: - ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed) - ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef) - - # merge attention scores and perform masking - attn_score = (ac + bd + ef) * self.scale - if attn_mask is not None: - # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask - if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16: - attn_score = attn_score - 65500 * attn_mask - else: - attn_score = attn_score - 1e30 * attn_mask - - # attention probability - attn_prob = stable_softmax(attn_score, axis=1) - - attn_prob = self.dropout(attn_prob, training=training) - - # Mask heads if we want to - if head_mask is not None: - attn_prob = attn_prob * head_mask - - # attention output - attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h) - - if output_attentions: - return attn_vec, attn_prob - - return attn_vec - - def post_attention(self, h, attn_vec, residual=True, training=False): - """Post-attention processing.""" - # post-attention projection (back to `d_model`) - attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o) - - attn_out = self.dropout(attn_out, training=training) - - if residual: - attn_out = attn_out + h - output = self.layer_norm(attn_out) - - return output - - def call( - self, - h, - g, - attn_mask_h, - attn_mask_g, - r, - seg_mat, - mems: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ): - if g is not None: - # Two-stream attention with relative positional encoding. - # content based attention score - if mems is not None and len(shape_list(mems)) > 1: - cat = tf.concat([mems, h], axis=0) - else: - cat = h - - # content-based key head - k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k) - - # content-based value head - v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v) - - # position-based key head - k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r) - - # h-stream - # content-stream query head - q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q) - - # core attention ops - attn_vec_h = self.rel_attn_core( - q_head_h, - k_head_h, - v_head_h, - k_head_r, - seg_mat, - attn_mask_h, - head_mask, - output_attentions, - training=training, - ) - - if output_attentions: - attn_vec_h, attn_prob_h = attn_vec_h - - # post processing - output_h = self.post_attention(h, attn_vec_h, training=training) - - # g-stream - # query-stream query head - q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q) - - # core attention ops - if target_mapping is not None: - q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping) - attn_vec_g = self.rel_attn_core( - q_head_g, - k_head_h, - v_head_h, - k_head_r, - seg_mat, - attn_mask_g, - head_mask, - output_attentions, - training=training, - ) - - if output_attentions: - attn_vec_g, attn_prob_g = attn_vec_g - - attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping) - else: - attn_vec_g = self.rel_attn_core( - q_head_g, - k_head_h, - v_head_h, - k_head_r, - seg_mat, - attn_mask_g, - head_mask, - output_attentions, - training=training, - ) - - if output_attentions: - attn_vec_g, attn_prob_g = attn_vec_g - - # post processing - output_g = self.post_attention(g, attn_vec_g, training=training) - - if output_attentions: - attn_prob = attn_prob_h, attn_prob_g - - else: - # Multi-head attention with relative positional encoding - if mems is not None and len(shape_list(mems)) > 1: - cat = tf.concat([mems, h], axis=0) - else: - cat = h - - # content heads - q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q) - k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k) - v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v) - - # positional heads - k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r) - - # core attention ops - attn_vec = self.rel_attn_core( - q_head_h, - k_head_h, - v_head_h, - k_head_r, - seg_mat, - attn_mask_h, - head_mask, - output_attentions, - training=training, - ) - - if output_attentions: - attn_vec, attn_prob = attn_vec - - # post processing - output_h = self.post_attention(h, attn_vec, training=training) - output_g = None - - outputs = (output_h, output_g) - if output_attentions: - outputs = outputs + (attn_prob,) - return outputs - - -class TFXLNetFeedForward(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.layer_1 = keras.layers.Dense( - config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1" - ) - self.layer_2 = keras.layers.Dense( - config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2" - ) - self.dropout = keras.layers.Dropout(config.dropout) - if isinstance(config.ff_activation, str): - self.activation_function = get_tf_activation(config.ff_activation) - else: - self.activation_function = config.ff_activation - self.config = config - - def call(self, inp, training=False): - output = inp - output = self.layer_1(output) - output = self.activation_function(output) - output = self.dropout(output, training=training) - output = self.layer_2(output) - output = self.dropout(output, training=training) - output = self.layer_norm(output + inp) - return output - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "layer_norm", None) is not None: - with tf.name_scope(self.layer_norm.name): - self.layer_norm.build([None, None, self.config.d_model]) - if getattr(self, "layer_1", None) is not None: - with tf.name_scope(self.layer_1.name): - self.layer_1.build([None, None, self.config.d_model]) - if getattr(self, "layer_2", None) is not None: - with tf.name_scope(self.layer_2.name): - self.layer_2.build([None, None, self.config.d_inner]) - - -class TFXLNetLayer(keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn") - self.ff = TFXLNetFeedForward(config, name="ff") - self.dropout = keras.layers.Dropout(config.dropout) - - def call( - self, - output_h, - output_g, - non_tgt_mask, - attn_mask, - pos_emb, - seg_mat, - mems: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - output_attentions: bool | None = False, - training: bool = False, - ): - outputs = self.rel_attn( - output_h, - output_g, - non_tgt_mask, - attn_mask, - pos_emb, - seg_mat, - mems, - target_mapping, - head_mask, - output_attentions, - training=training, - ) - output_h, output_g = outputs[:2] - - if output_g is not None: - output_g = self.ff(output_g, training=training) - output_h = self.ff(output_h, training=training) - - outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "rel_attn", None) is not None: - with tf.name_scope(self.rel_attn.name): - self.rel_attn.build(None) - if getattr(self, "ff", None) is not None: - with tf.name_scope(self.ff.name): - self.ff.build(None) - - -class TFXLNetLMHead(keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - self.config = config - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") - super().build(input_shape) - - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.input_embeddings(hidden_states, mode="linear") - hidden_states = hidden_states + self.bias - return hidden_states - - -@keras_serializable -class TFXLNetMainLayer(keras.layers.Layer): - config_class = XLNetConfig - - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.return_dict = config.return_dict - - self.mem_len = config.mem_len - self.reuse_len = config.reuse_len - self.d_model = config.d_model - self.same_length = config.same_length - self.attn_type = config.attn_type - self.bi_data = config.bi_data - self.clamp_len = config.clamp_len - self.n_layer = config.n_layer - self.use_bfloat16 = config.use_bfloat16 - self.initializer_range = config.initializer_range - - self.word_embedding = TFSharedEmbeddings( - config.vocab_size, config.d_model, initializer_range=config.initializer_range, name="word_embedding" - ) - self.layer = [TFXLNetLayer(config, name=f"layer_._{i}") for i in range(config.n_layer)] - self.dropout = keras.layers.Dropout(config.dropout) - - self.use_mems_eval = config.use_mems_eval - self.use_mems_train = config.use_mems_train - - def get_input_embeddings(self): - return self.word_embedding - - def set_input_embeddings(self, value): - self.word_embedding.weight = value - self.word_embedding.vocab_size = shape_list(value)[0] - - def build(self, input_shape=None): - initializer = get_initializer(self.initializer_range) - self.mask_emb = self.add_weight( - shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb" - ) - - if self.built: - return - self.built = True - if getattr(self, "word_embedding", None) is not None: - with tf.name_scope(self.word_embedding.name): - self.word_embedding.build(None) - if getattr(self, "layer", None) is not None: - for layer in self.layer: - with tf.name_scope(layer.name): - layer.build(None) - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError - - def create_mask(self, qlen, mlen): - """ - Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked. - - Args: - qlen: TODO Lysandre didn't fill - mlen: TODO Lysandre didn't fill - - ``` - - same_length=False: same_length=True: - < qlen > < qlen > - ^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1] - [0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1] - qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1] - [0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1] - v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0] - ``` - """ - attn_mask = tf.ones([qlen, qlen]) - mask_u = tf.linalg.band_part(attn_mask, 0, -1) - mask_dia = tf.linalg.band_part(attn_mask, 0, 0) - attn_mask_pad = tf.zeros([qlen, mlen]) - ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) - if self.same_length: - mask_l = tf.linalg.band_part(attn_mask, -1, 0) - ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) - return ret - - def cache_mem(self, curr_out, prev_mem): - # cache hidden states into memory. - if self.reuse_len is not None and self.reuse_len > 0: - curr_out = curr_out[: self.reuse_len] - - if self.mem_len is None or self.mem_len == 0: - # If `use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time - # and returns all of the past and current hidden states. - cutoff = 0 - else: - # If `use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden - # states. This is the preferred setting for training and long-form generation. - cutoff = -self.mem_len - if prev_mem is None: - # if `use_mems` is active and `mem_len` is defined, the model - new_mem = curr_out[cutoff:] - else: - new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff:] - - return tf.stop_gradient(new_mem) - - @staticmethod - def positional_embedding(pos_seq, inv_freq, bsz=None): - sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq) - pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1) - pos_emb = pos_emb[:, None, :] - - if bsz is not None: - pos_emb = tf.tile(pos_emb, [1, bsz, 1]) - - return pos_emb - - def relative_positional_encoding(self, qlen, klen, bsz=None): - """create relative positional encoding.""" - freq_seq = tf.range(0, self.d_model, 2.0) - inv_freq = 1 / (10000 ** (freq_seq / self.d_model)) - - if self.attn_type == "bi": - # beg, end = klen - 1, -qlen - beg, end = klen, -qlen - elif self.attn_type == "uni": - # beg, end = klen - 1, -1 - beg, end = klen, -1 - else: - raise ValueError(f"Unknown `attn_type` {self.attn_type}.") - - if self.bi_data: - fwd_pos_seq = tf.range(beg, end, -1.0) - bwd_pos_seq = tf.range(-beg, -end, 1.0) - - if self.clamp_len > 0: - fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len) - bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) - - if bsz is not None: - if bsz % 2 != 0: - raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2") - fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2) - bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2) - else: - fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq) - bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq) - - pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1) - else: - fwd_pos_seq = tf.range(beg, end, -1.0) - if self.clamp_len > 0: - fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len) - pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) - - return pos_emb - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ): - if training and use_mems is None: - use_mems = self.use_mems_train - else: - use_mems = self.use_mems_eval - - # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end - # but we want a unified interface in the library with the batch size on the first dimension - # so we move here the first dimension (batch) to the end - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_ids = tf.transpose(input_ids, perm=(1, 0)) - qlen, bsz = shape_list(input_ids)[:2] - elif inputs_embeds is not None: - inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) - qlen, bsz = shape_list(inputs_embeds)[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None - input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None - attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None - perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None - target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None - - mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0 - klen = mlen + qlen - - # Attention mask - # causal attention mask - if self.attn_type == "uni": - attn_mask = self.create_mask(qlen, mlen) - attn_mask = attn_mask[:, :, None, None] - elif self.attn_type == "bi": - attn_mask = None - else: - raise ValueError(f"Unsupported attention type: {self.attn_type}") - - # data mask: input mask & perm mask - assert input_mask is None or attention_mask is None, ( - "You can only use one of input_mask (uses 1 for padding) " - "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one." - ) - if input_mask is None and attention_mask is not None: - one_cst = tf.constant(1.0) - input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype) - if input_mask is not None and perm_mask is not None: - data_mask = input_mask[None] + perm_mask - elif input_mask is not None and perm_mask is None: - data_mask = input_mask[None] - elif input_mask is None and perm_mask is not None: - data_mask = perm_mask - else: - data_mask = None - - if data_mask is not None: - # all mems can be attended to - if mlen > 0: - mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz]) - data_mask = tf.concat([mems_mask, data_mask], axis=1) - if attn_mask is None: - attn_mask = data_mask[:, :, :, None] - else: - attn_mask += data_mask[:, :, :, None] - - if attn_mask is not None: - attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype) - - if attn_mask is not None: - non_tgt_mask = -tf.eye(qlen) - if mlen > 0: - non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1) - non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype) - else: - non_tgt_mask = None - - # Word embeddings and prepare h & g hidden states - if inputs_embeds is not None: - word_emb_k = inputs_embeds - else: - check_embeddings_within_bounds(input_ids, self.word_embedding.vocab_size) - word_emb_k = self.word_embedding(input_ids) - output_h = self.dropout(word_emb_k, training=training) - if target_mapping is not None: - word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1]) - # else: # We removed the inp_q input which was same as target mapping - # inp_q_ext = inp_q[:, :, None] - # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k - output_g = self.dropout(word_emb_q, training=training) - else: - output_g = None - - # Segment embedding - if token_type_ids is not None: - # Convert `token_type_ids` to one-hot `seg_mat` - if mlen > 0: - mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype) - cat_ids = tf.concat([mem_pad, token_type_ids], 0) - else: - cat_ids = token_type_ids - - # `1` indicates not in the same segment [qlen x klen x bsz] - seg_mat = tf.cast( - tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), - dtype=token_type_ids.dtype, - ) - seg_mat = tf.one_hot(seg_mat, 2) - else: - seg_mat = None - - # Positional encoding - pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) - pos_emb = self.dropout(pos_emb, training=training) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) - # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.n_layer - - new_mems = () - if mems is None: - mems = [None] * len(self.layer) - - attentions = [] if output_attentions else None - hidden_states = [] if output_hidden_states else None - for i, layer_module in enumerate(self.layer): - # cache new mems - if use_mems: - new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) - if output_hidden_states: - hidden_states.append((output_h, output_g) if output_g is not None else output_h) - - outputs = layer_module( - output_h, - output_g, - non_tgt_mask, - attn_mask, - pos_emb, - seg_mat, - mems[i], - target_mapping, - head_mask[i], - output_attentions, - training=training, - ) - output_h, output_g = outputs[:2] - if output_attentions: - attentions.append(outputs[2]) - - # Add last hidden state - if output_hidden_states: - hidden_states.append((output_h, output_g) if output_g is not None else output_h) - - output = self.dropout(output_g if output_g is not None else output_h, training=training) - - # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) - output = tf.transpose(output, perm=(1, 0, 2)) - - if not use_mems: - new_mems = None - if output_hidden_states: - if output_g is not None: - hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) - else: - hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) - if output_attentions: - if target_mapping is not None: - # when target_mapping is provided, there are 2-tuple of attentions - attentions = tuple( - tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions - ) - else: - attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) - - if not return_dict: - return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None) - - return TFXLNetModelOutput( - last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions - ) - - -class TFXLNetPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = XLNetConfig - base_model_prefix = "transformer" - - -@dataclass -class TFXLNetModelOutput(ModelOutput): - """ - Output type of [`TFXLNetModel`]. - - Args: - last_hidden_state (`tf.Tensor` of shape `(batch_size, num_predict, hidden_size)`): - Sequence of hidden-states at the last layer of the model. - - `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` - corresponds to `sequence_length`. - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFXLNetLMHeadModelOutput(ModelOutput): - """ - Output type of [`TFXLNetLMHeadModel`]. - - Args: - loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided) - Language modeling loss (for next-token prediction). - logits (`tf.Tensor` of shape `(batch_size, num_predict, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - - `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` - corresponds to `sequence_length`. - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFXLNetForSequenceClassificationOutput(ModelOutput): - """ - Output type of [`TFXLNetForSequenceClassification`]. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFXLNetForTokenClassificationOutput(ModelOutput): - """ - Output type of [`TFXLNetForTokenClassificationOutput`]. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFXLNetForMultipleChoiceOutput(ModelOutput): - """ - Output type of [`TFXLNetForMultipleChoice`]. - - Args: - loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided): - Classification loss. - logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - logits: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -@dataclass -class TFXLNetForQuestionAnsweringSimpleOutput(ModelOutput): - """ - Output type of [`TFXLNetForQuestionAnsweringSimple`]. - - Args: - loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`tf.Tensor` of shape `(batch_size, sequence_length,)`): - Span-start scores (before SoftMax). - end_logits (`tf.Tensor` of shape `(batch_size, sequence_length,)`): - Span-end scores (before SoftMax). - mems (`list[tf.Tensor]` of length `config.n_layers`): - Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The - token ids which have their past given to this model should not be passed as `input_ids` as they have - already been computed. - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: tf.Tensor | None = None - start_logits: tf.Tensor | None = None - end_logits: tf.Tensor | None = None - mems: list[tf.Tensor] | None = None - hidden_states: tuple[tf.Tensor, ...] | None = None - attentions: tuple[tf.Tensor, ...] | None = None - - -XLNET_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`XLNetConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -XLNET_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - mems (`list[torch.FloatTensor]` of length `config.n_layers`): - Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential - decoding. The token ids which have their past given to this model should not be passed as `input_ids` as - they have already been computed. - - `use_mems` has to be set to `True` to make use of `mems`. - perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*): - Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`: - - - if `perm_mask[k, i, j] = 0`, i attend to j in batch k; - - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k. - - If not set, each token attends to all the others (full bidirectional attention). Only used during - pretraining (to define factorization order) or for sequential decoding (generation). - target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*): - Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is - on the j-th token. Only used during pretraining for partial prediction or for sequential decoding - (generation). - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - input_mask (`torch.FloatTensor` of shape `{0}`, *optional*): - Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for - real tokens and 1 for padding which is kept for compatibility with the original code base. - - Mask values selected in `[0, 1]`: - - - 1 for tokens that are **masked**, - - 0 for tokens that are **not masked**. - - You can only uses one of `input_mask` and `attention_mask`. - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.", - XLNET_START_DOCSTRING, -) -class TFXLNetModel(TFXLNetPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLNetMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLNetModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - training: bool = False, - ) -> TFXLNetModelOutput | tuple[tf.Tensor]: - outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - mems=mems, - perm_mask=perm_mask, - target_mapping=target_mapping, - token_type_ids=token_type_ids, - input_mask=input_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_mems=use_mems, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - XLNet Model with a language modeling head on top (linear layer with weights tied to the input embeddings). - """, - XLNET_START_DOCSTRING, -) -class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLNetMainLayer(config, name="transformer") - self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss") - # generate fails to convert to a graph with XLNet - self.supports_xla_generation = False - - def get_lm_head(self): - return self.lm_loss - - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_loss.name - - def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_mems=None, **kwargs): - # Add dummy token at the end (no attention on this one) - effective_batch_size = inputs.shape[0] - dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype) - - # At every pass, the attention values for the new token and the two last generated tokens - # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have - # offset = 1; offset = 2 seems to have slightly better computation. - offset = 2 - - if past_key_values: - input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1) - else: - input_ids = tf.concat([inputs, dummy_token], axis=1) - - # Build permutation mask so that previous tokens don't see last token - sequence_length = input_ids.shape[1] - perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1)) - perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1)) - perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1) - - # We'll only predict the last token - target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1)) - target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1)) - target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) - - inputs = { - "input_ids": input_ids, - "perm_mask": perm_mask, - "target_mapping": target_mapping, - "use_mems": use_mems, - } - - # if past is defined in model kwargs then use it for faster decoding - if past_key_values: - inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values) - - return inputs - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFXLNetLMHeadModelOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> import numpy as np - >>> from transformers import AutoTokenizer, TFXLNetLMHeadModel - - >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-large-cased") - >>> model = TFXLNetLMHeadModel.from_pretrained("xlnet/xlnet-large-cased") - - >>> # We show how to setup inputs to predict a next token using a bi-directional context. - >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is very ", add_special_tokens=True))[ - ... None, : - ... ] # We will predict the masked token - - >>> perm_mask = np.zeros((1, input_ids.shape[1], input_ids.shape[1])) - >>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token - - >>> target_mapping = np.zeros( - ... (1, 1, input_ids.shape[1]) - ... ) # Shape [1, 1, seq_length] => let's predict one token - >>> target_mapping[ - ... 0, 0, -1 - ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token) - - >>> outputs = model( - ... input_ids, - ... perm_mask=tf.constant(perm_mask, dtype=tf.float32), - ... target_mapping=tf.constant(target_mapping, dtype=tf.float32), - ... ) - - >>> next_token_logits = outputs[ - ... 0 - ... ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] - ```""" - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - mems=mems, - perm_mask=perm_mask, - target_mapping=target_mapping, - token_type_ids=token_type_ids, - input_mask=input_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_mems=use_mems, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_state = transformer_outputs[0] - logits = self.lm_loss(hidden_state, training=training) - - loss = None - if labels is not None: - loss = self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFXLNetLMHeadModelOutput( - loss=loss, - logits=logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "lm_loss", None) is not None: - with tf.name_scope(self.lm_loss.name): - self.lm_loss.build(None) - - -@add_start_docstrings( - """ - XLNet Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. - for GLUE tasks. - """, - XLNET_START_DOCSTRING, -) -class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFXLNetMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="sequence_summary" - ) - self.logits_proj = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLNetForSequenceClassificationOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFXLNetForSequenceClassificationOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - mems=mems, - perm_mask=perm_mask, - target_mapping=target_mapping, - token_type_ids=token_type_ids, - input_mask=input_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_mems=use_mems, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - - output = self.sequence_summary(output) - logits = self.logits_proj(output) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFXLNetForSequenceClassificationOutput( - loss=loss, - logits=logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "logits_proj", None) is not None: - with tf.name_scope(self.logits_proj.name): - self.logits_proj.build([None, None, self.config.d_model]) - - -@add_start_docstrings( - """ - XLNET Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - XLNET_START_DOCSTRING, -) -class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.transformer = TFXLNetMainLayer(config, name="transformer") - self.sequence_summary = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="sequence_summary" - ) - self.logits_proj = keras.layers.Dense( - 1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLNetForMultipleChoiceOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFXLNetForMultipleChoiceOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) - """ - - if input_ids is not None: - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] - else: - num_choices = shape_list(inputs_embeds)[1] - seq_length = shape_list(inputs_embeds)[2] - - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None - flat_inputs_embeds = ( - tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) - if inputs_embeds is not None - else None - ) - transformer_outputs = self.transformer( - flat_input_ids, - flat_attention_mask, - mems, - perm_mask, - target_mapping, - flat_token_type_ids, - flat_input_mask, - head_mask, - flat_inputs_embeds, - use_mems, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - logits = self.sequence_summary(output) - logits = self.logits_proj(logits) - reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - - if not return_dict: - output = (reshaped_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFXLNetForMultipleChoiceOutput( - loss=loss, - logits=reshaped_logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "sequence_summary", None) is not None: - with tf.name_scope(self.sequence_summary.name): - self.sequence_summary.build(None) - if getattr(self, "logits_proj", None) is not None: - with tf.name_scope(self.logits_proj.name): - self.logits_proj.build([None, None, self.config.d_model]) - - -@add_start_docstrings( - """ - XLNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - XLNET_START_DOCSTRING, -) -class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - - self.transformer = TFXLNetMainLayer(config, name="transformer") - self.classifier = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLNetForTokenClassificationOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - labels: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFXLNetForTokenClassificationOutput | tuple[tf.Tensor]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - mems=mems, - perm_mask=perm_mask, - target_mapping=target_mapping, - token_type_ids=token_type_ids, - input_mask=input_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_mems=use_mems, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - output = transformer_outputs[0] - logits = self.classifier(output) - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFXLNetForTokenClassificationOutput( - loss=loss, - logits=logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "classifier", None) is not None: - with tf.name_scope(self.classifier.name): - self.classifier.build([None, None, self.config.hidden_size]) - - -@add_start_docstrings( - """ - XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - XLNET_START_DOCSTRING, -) -class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFXLNetMainLayer(config, name="transformer") - self.qa_outputs = keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" - ) - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFXLNetForQuestionAnsweringSimpleOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - mems: np.ndarray | tf.Tensor | None = None, - perm_mask: np.ndarray | tf.Tensor | None = None, - target_mapping: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - input_mask: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - use_mems: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - start_positions: np.ndarray | tf.Tensor | None = None, - end_positions: np.ndarray | tf.Tensor | None = None, - training: bool = False, - ) -> TFXLNetForQuestionAnsweringSimpleOutput | tuple[tf.Tensor]: - r""" - start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - mems=mems, - perm_mask=perm_mask, - target_mapping=target_mapping, - token_type_ids=token_type_ids, - input_mask=input_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_mems=use_mems, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - sequence_output = transformer_outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = tf.split(logits, 2, axis=-1) - start_logits = tf.squeeze(start_logits, axis=-1) - end_logits = tf.squeeze(end_logits, axis=-1) - - loss = None - if start_positions is not None and end_positions is not None: - labels = {"start_position": start_positions} - labels["end_position"] = end_positions - loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFXLNetForQuestionAnsweringSimpleOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - mems=transformer_outputs.mems, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "qa_outputs", None) is not None: - with tf.name_scope(self.qa_outputs.name): - self.qa_outputs.build([None, None, self.config.hidden_size]) - - -__all__ = [ - "TFXLNetForMultipleChoice", - "TFXLNetForQuestionAnsweringSimple", - "TFXLNetForSequenceClassification", - "TFXLNetForTokenClassification", - "TFXLNetLMHeadModel", - "TFXLNetMainLayer", - "TFXLNetModel", - "TFXLNetPreTrainedModel", -] diff --git a/src/transformers/optimization_tf.py b/src/transformers/optimization_tf.py deleted file mode 100644 index 71a77251f2bf..000000000000 --- a/src/transformers/optimization_tf.py +++ /dev/null @@ -1,378 +0,0 @@ -# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions and classes related to optimization (weight updates).""" - -from typing import Callable, Optional, Union - -import tensorflow as tf - - -try: - from tf_keras.optimizers.legacy import Adam -except (ImportError, ModuleNotFoundError): - from tensorflow.keras.optimizers.legacy import Adam - -from .modeling_tf_utils import keras - - -# This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15 -if hasattr(keras.optimizers.schedules, "learning_rate_schedule"): - schedules = keras.optimizers.schedules.learning_rate_schedule -else: - schedules = keras.optimizers.schedules - - -class WarmUp(schedules.LearningRateSchedule): - """ - Applies a warmup schedule on a given learning rate decay schedule. - - Args: - initial_learning_rate (`float`): - The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end - of the warmup). - decay_schedule_fn (`Callable`): - The schedule function to apply after the warmup for the rest of training. - warmup_steps (`int`): - The number of steps for the warmup part of training. - power (`float`, *optional*, defaults to 1.0): - The power to use for the polynomial warmup (defaults is a linear warmup). - name (`str`, *optional*): - Optional name prefix for the returned tensors during the schedule. - """ - - def __init__( - self, - initial_learning_rate: float, - decay_schedule_fn: Callable, - warmup_steps: int, - power: float = 1.0, - name: Optional[str] = None, - ): - super().__init__() - self.initial_learning_rate = initial_learning_rate - self.warmup_steps = warmup_steps - self.power = power - self.decay_schedule_fn = decay_schedule_fn - self.name = name - - def __call__(self, step): - with tf.name_scope(self.name or "WarmUp") as name: - # Implements polynomial warmup. i.e., if global_step < warmup_steps, the - # learning rate will be `global_step/num_warmup_steps * init_lr`. - global_step_float = tf.cast(step, tf.float32) - warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) - warmup_percent_done = global_step_float / warmup_steps_float - warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power) - return tf.cond( - global_step_float < warmup_steps_float, - lambda: warmup_learning_rate, - lambda: self.decay_schedule_fn(step - self.warmup_steps), - name=name, - ) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_schedule_fn": self.decay_schedule_fn, - "warmup_steps": self.warmup_steps, - "power": self.power, - "name": self.name, - } - - -def create_optimizer( - init_lr: float, - num_train_steps: int, - num_warmup_steps: int, - min_lr_ratio: float = 0.0, - adam_beta1: float = 0.9, - adam_beta2: float = 0.999, - adam_epsilon: float = 1e-8, - adam_clipnorm: Optional[float] = None, - adam_global_clipnorm: Optional[float] = None, - weight_decay_rate: float = 0.0, - power: float = 1.0, - include_in_weight_decay: Optional[list[str]] = None, -): - """ - Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay. - - Args: - init_lr (`float`): - The desired learning rate at the end of the warmup phase. - num_train_steps (`int`): - The total number of training steps. - num_warmup_steps (`int`): - The number of warmup steps. - min_lr_ratio (`float`, *optional*, defaults to 0): - The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`. - adam_beta1 (`float`, *optional*, defaults to 0.9): - The beta1 to use in Adam. - adam_beta2 (`float`, *optional*, defaults to 0.999): - The beta2 to use in Adam. - adam_epsilon (`float`, *optional*, defaults to 1e-8): - The epsilon to use in Adam. - adam_clipnorm (`float`, *optional*, defaults to `None`): - If not `None`, clip the gradient norm for each weight tensor to this value. - adam_global_clipnorm (`float`, *optional*, defaults to `None`) - If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all - weight tensors, as if they were concatenated into a single vector. - weight_decay_rate (`float`, *optional*, defaults to 0): - The weight decay to use. - power (`float`, *optional*, defaults to 1.0): - The power to use for PolynomialDecay. - include_in_weight_decay (`list[str]`, *optional*): - List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is - applied to all parameters except bias and layer norm parameters. - """ - # Implements linear decay of the learning rate. - lr_schedule = schedules.PolynomialDecay( - initial_learning_rate=init_lr, - decay_steps=num_train_steps - num_warmup_steps, - end_learning_rate=init_lr * min_lr_ratio, - power=power, - ) - if num_warmup_steps: - lr_schedule = WarmUp( - initial_learning_rate=init_lr, - decay_schedule_fn=lr_schedule, - warmup_steps=num_warmup_steps, - ) - if weight_decay_rate > 0.0: - optimizer = AdamWeightDecay( - learning_rate=lr_schedule, - weight_decay_rate=weight_decay_rate, - beta_1=adam_beta1, - beta_2=adam_beta2, - epsilon=adam_epsilon, - clipnorm=adam_clipnorm, - global_clipnorm=adam_global_clipnorm, - exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], - include_in_weight_decay=include_in_weight_decay, - ) - else: - optimizer = keras.optimizers.Adam( - learning_rate=lr_schedule, - beta_1=adam_beta1, - beta_2=adam_beta2, - epsilon=adam_epsilon, - clipnorm=adam_clipnorm, - global_clipnorm=adam_global_clipnorm, - ) - # We return the optimizer and the LR scheduler in order to better track the - # evolution of the LR independently of the optimizer. - return optimizer, lr_schedule - - -class AdamWeightDecay(Adam): - """ - Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the - loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact - with the m and v parameters in strange ways as shown in [Decoupled Weight Decay - Regularization](https://huggingface.co/papers/1711.05101). - - Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent - to adding the square of the weights to the loss with plain (non-momentum) SGD. - - Args: - learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001): - The learning rate to use or a schedule. - beta_1 (`float`, *optional*, defaults to 0.9): - The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates. - beta_2 (`float`, *optional*, defaults to 0.999): - The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates. - epsilon (`float`, *optional*, defaults to 1e-07): - The epsilon parameter in Adam, which is a small constant for numerical stability. - amsgrad (`bool`, *optional*, defaults to `False`): - Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and - Beyond](https://huggingface.co/papers/1904.09237). - weight_decay_rate (`float`, *optional*, defaults to 0.0): - The weight decay to apply. - include_in_weight_decay (`list[str]`, *optional*): - List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is - applied to all parameters by default (unless they are in `exclude_from_weight_decay`). - exclude_from_weight_decay (`list[str]`, *optional*): - List of the parameter names (or re patterns) to exclude from applying weight decay to. If a - `include_in_weight_decay` is passed, the names in it will supersede this list. - name (`str`, *optional*, defaults to `"AdamWeightDecay"`): - Optional name for the operations created when applying gradients. - kwargs (`dict[str, Any]`, *optional*): - Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by - norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time - inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use - `learning_rate` instead. - """ - - def __init__( - self, - learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001, - beta_1: float = 0.9, - beta_2: float = 0.999, - epsilon: float = 1e-7, - amsgrad: bool = False, - weight_decay_rate: float = 0.0, - include_in_weight_decay: Optional[list[str]] = None, - exclude_from_weight_decay: Optional[list[str]] = None, - name: str = "AdamWeightDecay", - **kwargs, - ): - super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) - self.weight_decay_rate = weight_decay_rate - self._include_in_weight_decay = include_in_weight_decay - self._exclude_from_weight_decay = exclude_from_weight_decay - - @classmethod - def from_config(cls, config): - """Creates an optimizer from its config with WarmUp custom object.""" - custom_objects = {"WarmUp": WarmUp} - return super().from_config(config, custom_objects=custom_objects) - - def _prepare_local(self, var_device, var_dtype, apply_state): - super()._prepare_local(var_device, var_dtype, apply_state) - apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant( - self.weight_decay_rate, name="adam_weight_decay_rate" - ) - - def _decay_weights_op(self, var, learning_rate, apply_state): - do_decay = self._do_use_weight_decay(var.name) - if do_decay: - return var.assign_sub( - learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"], - use_locking=self._use_locking, - ) - return tf.no_op() - - def apply_gradients(self, grads_and_vars, name=None, **kwargs): - grads, tvars = list(zip(*grads_and_vars)) - return super().apply_gradients(zip(grads, tvars), name=name, **kwargs) - - def _get_lr(self, var_device, var_dtype, apply_state): - """Retrieves the learning rate with the given state.""" - if apply_state is None: - return self._decayed_lr_t[var_dtype], {} - - apply_state = apply_state or {} - coefficients = apply_state.get((var_device, var_dtype)) - if coefficients is None: - coefficients = self._fallback_apply_state(var_device, var_dtype) - apply_state[(var_device, var_dtype)] = coefficients - - return coefficients["lr_t"], {"apply_state": apply_state} - - def _resource_apply_dense(self, grad, var, apply_state=None): - lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) - decay = self._decay_weights_op(var, lr_t, apply_state) - with tf.control_dependencies([decay]): - return super()._resource_apply_dense(grad, var, **kwargs) - - def _resource_apply_sparse(self, grad, var, indices, apply_state=None): - lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) - decay = self._decay_weights_op(var, lr_t, apply_state) - with tf.control_dependencies([decay]): - return super()._resource_apply_sparse(grad, var, indices, **kwargs) - - def get_config(self): - config = super().get_config() - config.update({"weight_decay_rate": self.weight_decay_rate}) - return config - - def _do_use_weight_decay(self, param_name): - """Whether to use L2 weight decay for `param_name`.""" - if self.weight_decay_rate == 0: - return False - - if self._include_in_weight_decay: - for r in self._include_in_weight_decay: - if r in param_name: - return True - - if self._exclude_from_weight_decay: - for r in self._exclude_from_weight_decay: - if r in param_name: - return False - return True - - -# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py -class GradientAccumulator: - """ - Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a - replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should - then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`. - """ - - # We use the ON_READ synchronization policy so that no synchronization is - # performed on assignment. To get the value, we call .value() which returns the - # value on the current replica without synchronization. - - def __init__(self): - """Initializes the accumulator.""" - self._gradients = [] - self._accum_steps = None - - @property - def step(self): - """Number of accumulated steps.""" - if self._accum_steps is None: - self._accum_steps = tf.Variable( - tf.constant(0, dtype=tf.int64), - trainable=False, - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - - return self._accum_steps.value() - - @property - def gradients(self): - """The accumulated gradients on the current replica.""" - if not self._gradients: - raise ValueError("The accumulator should be called first to initialize the gradients") - return [gradient.value() if gradient is not None else gradient for gradient in self._gradients] - - def __call__(self, gradients): - """Accumulates `gradients` on the current replica.""" - if not self._gradients: - _ = self.step # Create the step variable. - self._gradients.extend( - [ - tf.Variable( - tf.zeros_like(gradient), - trainable=False, - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - if gradient is not None - else gradient - for gradient in gradients - ] - ) - if len(gradients) != len(self._gradients): - raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}") - - for accum_gradient, gradient in zip(self._gradients, gradients): - if accum_gradient is not None and gradient is not None: - accum_gradient.assign_add(gradient) - - self._accum_steps.assign_add(1) - - def reset(self): - """Resets the accumulated gradients on the current replica.""" - if not self._gradients: - return - self._accum_steps.assign(0) - for gradient in self._gradients: - if gradient is not None: - gradient.assign(tf.zeros_like(gradient)) diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py deleted file mode 100644 index 11d07f8d7eda..000000000000 --- a/src/transformers/tf_utils.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Union - -import numpy as np -import tensorflow as tf - -from .feature_extraction_utils import BatchFeature -from .tokenization_utils_base import BatchEncoding -from .utils import logging - - -logger = logging.get_logger(__name__) - - -def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> list[int]: - """ - Deal with dynamic shape in tensorflow cleanly. - - Args: - tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. - - Returns: - `list[int]`: The shape of the tensor as a list. - """ - if isinstance(tensor, np.ndarray): - return list(tensor.shape) - - dynamic = tf.shape(tensor) - - if tensor.shape == tf.TensorShape(None): - return dynamic - - static = tensor.shape.as_list() - - return [dynamic[i] if s is None else s for i, s in enumerate(static)] - - -def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor: - """ - Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is - meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be - removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and relies on the fact that - `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html). - - Args: - logits (`tf.Tensor`): - Must be one of the following types: half, float32, float64. - axis (`int`, *optional*): - The dimension softmax would be performed on. The default is -1 which indicates the last dimension. - name (`str`, *optional*): - A name for the operation. - - Returns: - `tf.Tensor`: - A Tensor. Has the same type and shape as logits. - """ - # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if - # it has the fix. After we drop the support for unfixed versions, remove this function. - return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) - - -def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): - # This is a very simplified functional layernorm, designed to duplicate - # the functionality of PyTorch nn.functional.layer_norm when this is needed to port - # models in Transformers. - - if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): - raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") - - # Get mean and variance on the axis to be normalized - mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) - - if axis != -1: - # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions - # on every dimension except axis - shape = [1] * inputs.shape.rank - shape[axis] = shape_list(inputs)[axis] - weight = tf.reshape(weight, shape) - bias = tf.reshape(bias, shape) - - # Compute layer normalization using the batch_normalization - # function. - outputs = tf.nn.batch_normalization( - inputs, - mean, - variance, - offset=bias, - scale=weight, - variance_epsilon=epsilon, - ) - return outputs - - -def scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: Optional[float] = None -): - """TF equivalent for torch's nn.functional.scaled_dot_product_attention""" - if dropout_p != 0.0: - raise ValueError( - "Dropout is not supported in this implementation - file an issue " - "with Transformers and ping @Rocketknight1 if you need it for a port!" - ) - if is_causal and attn_mask is not None: - raise ValueError("You cannot specify an attn_mask and is_causal at the same time!") - if is_causal: - attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32) - attn_mask = tf.experimental.numpy.tril(attn_mask, k=0) - if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool): - # Convert boolean mask to a negative logit bias - attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype)) - logits = tf.einsum("...qd, ...kd -> ...qk", query, key) - if scale is None: - scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5 - logits *= scale # scale by 1/sqrt(key_dim) - if attn_mask is not None: - logits += attn_mask - probs = tf.nn.softmax(logits) - return probs @ value - - -def flatten(input, start_dim=0, end_dim=-1): - # Replicates the behavior of torch.flatten in TF - - # If end_dim or start_dim is negative, count them from the end - if end_dim < 0: - end_dim += input.shape.rank - if start_dim < 0: - start_dim += input.shape.rank - - if start_dim == end_dim: - return input - - in_shape = tf.shape(input) - flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1]) - out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0) - return tf.reshape(input, out_shape) - - -def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor: - """ - Invert an attention mask (e.g., switches 0. and 1.). - - Args: - encoder_attention_mask (`torch.Tensor`): An attention mask. - - Returns: - `tf.Tensor`: The inverted attention mask. - """ - if not isinstance(encoder_attention_mask, tf.Tensor): - encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs - if encoder_attention_mask.shape.rank == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if encoder_attention_mask.shape.rank == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow - # /transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = (encoder_extended_attention_mask == - # encoder_extended_attention_mask.transpose(-1, -2)) - encoder_extended_attention_mask = ( - tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask - ) * encoder_extended_attention_mask.dtype.min - - return encoder_extended_attention_mask - - -def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None: - """ - `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning - zeros instead. This function adds a check against that dangerous silent behavior. - - Args: - tensor (`tf.Tensor`): The tensor of indices to check. - embed_dim (`int`): The embedding dimension. - tensor_name (`str`, *optional*): The name of the tensor to use in the error message. - """ - tf.debugging.assert_less( - tensor, - tf.cast(embed_dim, dtype=tensor.dtype), - message=( - f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding " - f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time." - ), - ) - - -def save_attributes_to_hdf5_group(group, name, data): - """Saves attributes (data) of the specified name into the HDF5 group. - - This method deals with an inherent problem of HDF5 file which is not able to store data larger than - HDF5_OBJECT_HEADER_LIMIT bytes. - - Args: - group: A pointer to a HDF5 group. - name: A name of the attributes to save. - data: Attributes data to store. - - Raises: - RuntimeError: If any single attribute is too large to be saved. - - Copied from Keras to Transformers to avoid versioning issues. - """ - HDF5_OBJECT_HEADER_LIMIT = 64512 - # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` - # because in that case even chunking the array would not make the saving - # possible. - bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] - - # Expecting this to never be true. - if bad_attributes: - raise RuntimeError( - "The following attributes cannot be saved to HDF5 file because " - f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} " - f"bytes: {bad_attributes}" - ) - - data_npy = np.asarray(data) - - num_chunks = 1 - chunked_data = np.array_split(data_npy, num_chunks) - - # This will never loop forever thanks to the test above. - while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): - num_chunks += 1 - chunked_data = np.array_split(data_npy, num_chunks) - - if num_chunks > 1: - for chunk_id, chunk_data in enumerate(chunked_data): - group.attrs["%s%d" % (name, chunk_id)] = chunk_data - else: - group.attrs[name] = data - - -def load_attributes_from_hdf5_group(group, name): - """Loads attributes of the specified name from the HDF5 group. - - This method deals with an inherent problem of HDF5 file which is not able to store data larger than - HDF5_OBJECT_HEADER_LIMIT bytes. - - Args: - group: A pointer to a HDF5 group. - name: A name of the attributes to load. - - Returns: - data: Attributes data. - - Copied from Keras to Transformers to avoid versioning issues. - """ - if name in group.attrs: - data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]] - else: - data = [] - chunk_id = 0 - while "%s%d" % (name, chunk_id) in group.attrs: - data.extend( - [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]] - ) - chunk_id += 1 - return data - - -def expand_1d(data): - """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s. - Copied from Keras to here to avoid versioning issues.""" - - def _expand_single_1d_tensor(t): - if isinstance(t, tf.Tensor) and t.shape.rank == 1: - return tf.expand_dims(t, axis=-1) - return t - - return tf.nest.map_structure(_expand_single_1d_tensor, data) - - -def convert_batch_encoding(*args, **kwargs): - # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands - if args and isinstance(args[0], (BatchEncoding, BatchFeature)): - args = list(args) - args[0] = dict(args[0]) - elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)): - kwargs["x"] = dict(kwargs["x"]) - return args, kwargs diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py deleted file mode 100644 index 24763dabf916..000000000000 --- a/src/transformers/training_args_tf.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from functools import cached_property -from typing import Optional - -from .training_args import TrainingArguments -from .utils import is_tf_available, logging, requires_backends - - -logger = logging.get_logger(__name__) - -if is_tf_available(): - import tensorflow as tf - - from .modeling_tf_utils import keras - - -@dataclass -class TFTrainingArguments(TrainingArguments): - """ - TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop - itself**. - - Using [`HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - output_dir (`str`): - The output directory where the model predictions and checkpoints will be written. - overwrite_output_dir (`bool`, *optional*, defaults to `False`): - If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` - points to a checkpoint directory. - do_train (`bool`, *optional*, defaults to `False`): - Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used - by your training/evaluation scripts instead. See the [example - scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. - do_eval (`bool`, *optional*): - Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is - different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your - training/evaluation scripts instead. See the [example - scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. - do_predict (`bool`, *optional*, defaults to `False`): - Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's - intended to be used by your training/evaluation scripts instead. See the [example - scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. - eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): - The evaluation strategy to adopt during training. Possible values are: - - - `"no"`: No evaluation is done during training. - - `"steps"`: Evaluation is done (and logged) every `eval_steps`. - - `"epoch"`: Evaluation is done at the end of each epoch. - - per_device_train_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/TPU core/CPU for training. - per_device_eval_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/TPU core/CPU for evaluation. - gradient_accumulation_steps (`int`, *optional*, defaults to 1): - Number of updates steps to accumulate the gradients for, before performing a backward/update pass. - - - - When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, - evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. - - - - learning_rate (`float`, *optional*, defaults to 5e-5): - The initial learning rate for Adam. - weight_decay (`float`, *optional*, defaults to 0): - The weight decay to apply (if not zero). - adam_beta1 (`float`, *optional*, defaults to 0.9): - The beta1 hyperparameter for the Adam optimizer. - adam_beta2 (`float`, *optional*, defaults to 0.999): - The beta2 hyperparameter for the Adam optimizer. - adam_epsilon (`float`, *optional*, defaults to 1e-8): - The epsilon hyperparameter for the Adam optimizer. - max_grad_norm (`float`, *optional*, defaults to 1.0): - Maximum gradient norm (for gradient clipping). - num_train_epochs(`float`, *optional*, defaults to 3.0): - Total number of training epochs to perform. - max_steps (`int`, *optional*, defaults to -1): - If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. - For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until - `max_steps` is reached. - warmup_ratio (`float`, *optional*, defaults to 0.0): - Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. - warmup_steps (`int`, *optional*, defaults to 0): - Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`. - logging_dir (`str`, *optional*): - [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to - *runs/**CURRENT_DATETIME_HOSTNAME***. - logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): - The logging strategy to adopt during training. Possible values are: - - - `"no"`: No logging is done during training. - - `"epoch"`: Logging is done at the end of each epoch. - - `"steps"`: Logging is done every `logging_steps`. - - logging_first_step (`bool`, *optional*, defaults to `False`): - Whether to log and evaluate the first `global_step` or not. - logging_steps (`int`, *optional*, defaults to 500): - Number of update steps between two logs if `logging_strategy="steps"`. - save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): - The checkpoint save strategy to adopt during training. Possible values are: - - - `"no"`: No save is done during training. - - `"epoch"`: Save is done at the end of each epoch. - - `"steps"`: Save is done every `save_steps`. - - save_steps (`int`, *optional*, defaults to 500): - Number of updates steps before two checkpoint saves if `save_strategy="steps"`. - save_total_limit (`int`, *optional*): - If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in - `output_dir`. - no_cuda (`bool`, *optional*, defaults to `False`): - Whether to not use CUDA even when it is available or not. - seed (`int`, *optional*, defaults to 42): - Random seed that will be set at the beginning of training. - fp16 (`bool`, *optional*, defaults to `False`): - Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training. - fp16_opt_level (`str`, *optional*, defaults to 'O1'): - For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on - the [Apex documentation](https://nvidia.github.io/apex/amp). - local_rank (`int`, *optional*, defaults to -1): - During distributed training, the rank of the process. - tpu_num_cores (`int`, *optional*): - When training on TPU, the number of TPU cores (automatically passed by launcher script). - debug (`bool`, *optional*, defaults to `False`): - Whether to activate the trace to record computation graphs and profiling information or not. - dataloader_drop_last (`bool`, *optional*, defaults to `False`): - Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) - or not. - eval_steps (`int`, *optional*, defaults to 1000): - Number of update steps before two evaluations. - past_index (`int`, *optional*, defaults to -1): - Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make - use of the past hidden states for their predictions. If this argument is set to a positive int, the - `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at - the next training step under the keyword argument `mems`. - tpu_name (`str`, *optional*): - The name of the TPU the process is running on. - tpu_zone (`str`, *optional*): - The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect - from metadata. - gcp_project (`str`, *optional*): - Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to - automatically detect from metadata. - run_name (`str`, *optional*): - A descriptor for the run. Notably used for trackio, wandb, mlflow, comet and swanlab logging. - xla (`bool`, *optional*): - Whether to activate the XLA compilation or not. - """ - - framework = "tf" - tpu_name: Optional[str] = field( - default=None, - metadata={"help": "Name of TPU"}, - ) - - tpu_zone: Optional[str] = field( - default=None, - metadata={"help": "Zone of TPU"}, - ) - - gcp_project: Optional[str] = field( - default=None, - metadata={"help": "Name of Cloud TPU-enabled project"}, - ) - - poly_power: float = field( - default=1.0, - metadata={"help": "Power for the Polynomial decay LR scheduler."}, - ) - - xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"}) - - @cached_property - def _setup_strategy(self) -> tuple["tf.distribute.Strategy", int]: - requires_backends(self, ["tf"]) - logger.info("Tensorflow: setting up strategy") - - gpus = tf.config.list_physical_devices("GPU") - - # Set to float16 at first - if self.fp16: - keras.mixed_precision.set_global_policy("mixed_float16") - - if self.no_cuda: - strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") - else: - try: - if self.tpu_name: - tpu = tf.distribute.cluster_resolver.TPUClusterResolver( - self.tpu_name, zone=self.tpu_zone, project=self.gcp_project - ) - else: - tpu = tf.distribute.cluster_resolver.TPUClusterResolver() - except ValueError: - if self.tpu_name: - raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!") - else: - tpu = None - - if tpu: - # Set to bfloat16 in case of TPU - if self.fp16: - keras.mixed_precision.set_global_policy("mixed_bfloat16") - - tf.config.experimental_connect_to_cluster(tpu) - tf.tpu.experimental.initialize_tpu_system(tpu) - - strategy = tf.distribute.TPUStrategy(tpu) - - elif len(gpus) == 0: - strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") - elif len(gpus) == 1: - strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0") - elif len(gpus) > 1: - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` - strategy = tf.distribute.MirroredStrategy() - else: - raise ValueError("Cannot find the proper strategy, please check your environment properties.") - - return strategy - - @property - def strategy(self) -> "tf.distribute.Strategy": - """ - The strategy used for distributed training. - """ - requires_backends(self, ["tf"]) - return self._setup_strategy - - @property - def n_replicas(self) -> int: - """ - The number of replicas (CPUs, GPUs or TPU cores) used in this training. - """ - requires_backends(self, ["tf"]) - return self._setup_strategy.num_replicas_in_sync - - @property - def should_log(self): - """ - Whether or not the current process should produce log. - """ - return False # TF Logging is handled by Keras not the Trainer - - @property - def train_batch_size(self) -> int: - """ - The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training). - """ - if self.per_gpu_train_batch_size: - logger.warning( - "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " - "version. Using `--per_device_train_batch_size` is preferred." - ) - per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size - return per_device_batch_size * self.n_replicas - - @property - def eval_batch_size(self) -> int: - """ - The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training). - """ - if self.per_gpu_eval_batch_size: - logger.warning( - "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " - "version. Using `--per_device_eval_batch_size` is preferred." - ) - per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size - return per_device_batch_size * self.n_replicas - - @property - def n_gpu(self) -> int: - """ - The number of replicas (CPUs, GPUs or TPU cores) used in this training. - """ - requires_backends(self, ["tf"]) - warnings.warn( - "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.", - FutureWarning, - ) - return self._setup_strategy.num_replicas_in_sync diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py deleted file mode 100644 index 6f886de28246..000000000000 --- a/src/transformers/utils/dummy_flax_objects.py +++ /dev/null @@ -1,107 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class FlaxForcedBOSTokenLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxForceTokensLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxGenerationMixin(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxLogitsProcessorList(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxLogitsWarper(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxMinLengthLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxTemperatureLogitsWarper(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxTopKLogitsWarper(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxTopPLogitsWarper(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxPreTrainedModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py deleted file mode 100644 index de7b6f505df5..000000000000 --- a/src/transformers/utils/dummy_tf_objects.py +++ /dev/null @@ -1,178 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class TFForcedBOSTokenLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFForceTokensLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFGenerationMixin(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFLogitsProcessorList(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFLogitsWarper(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFMinLengthLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFNoBadWordsLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFNoRepeatNGramLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSuppressTokensLogitsProcessor(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFTemperatureLogitsWarper(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFTopKLogitsWarper(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFTopPLogitsWarper(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class KerasMetricCallback(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class PushToHubCallback(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFPreTrainedModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSequenceSummary(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSharedEmbeddings(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -def shape_list(*args, **kwargs): - requires_backends(shape_list, ["tf"]) - - -class AdamWeightDecay(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class GradientAccumulator(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class WarmUp(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -def create_optimizer(*args, **kwargs): - requires_backends(create_optimizer, ["tf"]) diff --git a/tests/sagemaker/scripts/tensorflow/run_tf.py b/tests/sagemaker/scripts/tensorflow/run_tf.py deleted file mode 100644 index a5b8e3fe1f05..000000000000 --- a/tests/sagemaker/scripts/tensorflow/run_tf.py +++ /dev/null @@ -1,104 +0,0 @@ -import argparse -import logging -import sys -import time - -import tensorflow as tf -from datasets import load_dataset -from packaging.version import parse - -from transformers import AutoTokenizer, TFAutoModelForSequenceClassification - - -try: - import tf_keras as keras -except (ModuleNotFoundError, ImportError): - import keras - - if parse(keras.__version__).major > 2: - raise ValueError( - "Your currently installed version of Keras is Keras 3, but this is not yet supported in " - "Transformers. Please install the backwards-compatible tf-keras package with " - "`pip install tf-keras`." - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - # Hyperparameters sent by the client are passed as command-line arguments to the script. - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--per_device_train_batch_size", type=int, default=16) - parser.add_argument("--per_device_eval_batch_size", type=int, default=8) - parser.add_argument("--model_name_or_path", type=str) - parser.add_argument("--learning_rate", type=str, default=5e-5) - parser.add_argument("--do_train", type=bool, default=True) - parser.add_argument("--do_eval", type=bool, default=True) - parser.add_argument("--output_dir", type=str) - - args, _ = parser.parse_known_args() - - # overwrite batch size until we have tf_glue.py - args.per_device_train_batch_size = 16 - args.per_device_eval_batch_size = 16 - - # Set up logging - logger = logging.getLogger(__name__) - - logging.basicConfig( - level=logging.getLevelName("INFO"), - handlers=[logging.StreamHandler(sys.stdout)], - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - # Load model and tokenizer - model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) - - # Load dataset - train_dataset, test_dataset = load_dataset("stanfordnlp/imdb", split=["train", "test"]) - train_dataset = train_dataset.shuffle().select(range(5000)) # smaller the size for train dataset to 5k - test_dataset = test_dataset.shuffle().select(range(500)) # smaller the size for test dataset to 500 - - # Preprocess train dataset - train_dataset = train_dataset.map( - lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True - ) - train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) - - train_features = { - x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) - for x in ["input_ids", "attention_mask"] - } - tf_train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_dataset["label"])).batch( - args.per_device_train_batch_size - ) - - # Preprocess test dataset - test_dataset = test_dataset.map( - lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True - ) - test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"]) - - test_features = { - x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length]) - for x in ["input_ids", "attention_mask"] - } - tf_test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_dataset["label"])).batch( - args.per_device_eval_batch_size - ) - - # fine optimizer and loss - optimizer = keras.optimizers.Adam(learning_rate=args.learning_rate) - loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metrics = [keras.metrics.SparseCategoricalAccuracy()] - model.compile(optimizer=optimizer, loss=loss, metrics=metrics) - - start_train_time = time.time() - train_results = model.fit(tf_train_dataset, epochs=args.epochs, batch_size=args.per_device_train_batch_size) - end_train_time = time.time() - start_train_time - - logger.info("*** Train ***") - logger.info(f"train_runtime = {end_train_time}") - for key, value in train_results.history.items(): - logger.info(f" {key} = {value}") diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py deleted file mode 100644 index f6c2b8bae4e2..000000000000 --- a/utils/check_tf_ops.py +++ /dev/null @@ -1,101 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import os - -from tensorflow.core.protobuf.saved_model_pb2 import SavedModel - - -# All paths are set with the intent you should run this script from the root of the repo with the command -# python utils/check_copies.py -REPO_PATH = "." - -# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model) -INTERNAL_OPS = [ - "Assert", - "AssignVariableOp", - "EmptyTensorList", - "MergeV2Checkpoints", - "ReadVariableOp", - "ResourceGather", - "RestoreV2", - "SaveV2", - "ShardedFilename", - "StatefulPartitionedCall", - "StaticRegexFullMatch", - "VarHandleOp", -] - - -def onnx_compliancy(saved_model_path, strict, opset): - saved_model = SavedModel() - onnx_ops = [] - - with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f: - onnx_opsets = json.load(f)["opsets"] - - for i in range(1, opset + 1): - onnx_ops.extend(onnx_opsets[str(i)]) - - with open(saved_model_path, "rb") as f: - saved_model.ParseFromString(f.read()) - - model_op_names = set() - - # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs) - for meta_graph in saved_model.meta_graphs: - # Add operations in the graph definition - model_op_names.update(node.op for node in meta_graph.graph_def.node) - - # Go through the functions in the graph definition - for func in meta_graph.graph_def.library.function: - # Add operations in each function - model_op_names.update(node.op for node in func.node_def) - - # Convert to list, sorted if you want - model_op_names = sorted(model_op_names) - incompatible_ops = [] - - for op in model_op_names: - if op not in onnx_ops and op not in INTERNAL_OPS: - incompatible_ops.append(op) - - if strict and len(incompatible_ops) > 0: - raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops) - elif len(incompatible_ops) > 0: - print(f"Found the following incompatible ops for the opset {opset}:") - print(*incompatible_ops, sep="\n") - else: - print(f"The saved model {saved_model_path} can properly be converted with ONNX.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).") - parser.add_argument( - "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested." - ) - parser.add_argument( - "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model." - ) - parser.add_argument( - "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)" - ) - args = parser.parse_args() - - if args.framework == "onnx": - onnx_compliancy(args.saved_model_path, args.strict, args.opset) From 176fed4ae33b104b00f84c998392837c1d7bc51c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 14:10:31 +0200 Subject: [PATCH 03/35] continue the purge --- conftest.py | 2 - .../multiple_choice/utils_multiple_choice.py | 90 +--- .../legacy/token-classification/utils_ner.py | 99 +--- src/transformers/__init__.py | 135 +---- src/transformers/commands/env.py | 33 -- src/transformers/commands/train.py | 158 ------ src/transformers/convert_graph_to_onnx.py | 79 +-- src/transformers/data/data_collator.py | 482 +----------------- src/transformers/data/processors/glue.py | 51 +- src/transformers/data/processors/squad.py | 113 +--- src/transformers/data/processors/utils.py | 23 +- src/transformers/feature_extraction_utils.py | 30 +- src/transformers/file_utils.py | 3 - src/transformers/generation/__init__.py | 124 ----- src/transformers/image_transforms.py | 46 +- src/transformers/image_utils.py | 14 +- src/transformers/keras_callbacks.py | 413 --------------- src/transformers/modelcard.py | 135 ----- src/transformers/modeling_utils.py | 10 +- .../models/albert/modeling_albert.py | 134 +---- .../models/altclip/modeling_altclip.py | 3 +- src/transformers/models/bert/modeling_bert.py | 81 +-- .../modeling_bert_generation.py | 90 +--- .../models/big_bird/modeling_big_bird.py | 169 +----- .../modeling_bigbird_pegasus.py | 2 - .../bridgetower/modeling_bridgetower.py | 3 +- .../models/camembert/modeling_camembert.py | 5 +- .../models/canine/modeling_canine.py | 105 +--- src/transformers/models/clap/modeling_clap.py | 3 +- .../models/codegen/tokenization_codegen.py | 11 +- .../codegen/tokenization_codegen_fast.py | 8 +- .../image_processing_conditional_detr.py | 17 +- .../models/convbert/modeling_convbert.py | 133 +---- .../models/data2vec/modeling_data2vec_text.py | 5 +- .../modeling_decision_transformer.py | 58 --- .../image_processing_deformable_detr.py | 19 +- .../deprecated/deta/image_processing_deta.py | 18 +- .../jukebox/tokenization_jukebox.py | 23 +- .../models/deprecated/nezha/modeling_nezha.py | 77 --- .../deprecated/qdqbert/modeling_qdqbert.py | 78 --- .../models/deprecated/realm/modeling_realm.py | 112 ---- .../deprecated/realm/retrieval_realm.py | 10 - .../retribert/modeling_retribert.py | 1 - .../modeling_trajectory_transformer.py | 76 --- .../transfo_xl/modeling_transfo_xl.py | 129 ----- .../models/detr/image_processing_detr.py | 19 +- .../models/distilbert/modeling_distilbert.py | 1 - src/transformers/models/dpr/modeling_dpr.py | 3 - .../models/electra/modeling_electra.py | 87 ---- src/transformers/models/esm/modeling_esm.py | 2 - .../models/flaubert/modeling_flaubert.py | 1 - .../models/funnel/modeling_funnel.py | 93 ---- src/transformers/models/gpt2/modeling_gpt2.py | 61 --- .../models/gpt_neo/modeling_gpt_neo.py | 87 ---- .../image_processing_grounding_dino.py | 19 +- .../models/idefics/processing_idefics.py | 75 +-- .../models/imagegpt/modeling_imagegpt.py | 113 ---- .../models/lxmert/modeling_lxmert.py | 83 --- .../models/markuplm/modeling_markuplm.py | 2 - .../megatron_bert/modeling_megatron_bert.py | 73 --- .../models/mobilebert/modeling_mobilebert.py | 83 --- .../mobilenet_v1/modeling_mobilenet_v1.py | 105 ---- .../mobilenet_v2/modeling_mobilenet_v2.py | 171 ------- src/transformers/models/mt5/modeling_mt5.py | 107 ---- .../models/owlv2/processing_owlv2.py | 17 +- .../models/owlvit/processing_owlvit.py | 18 +- .../models/rembert/modeling_rembert.py | 88 ---- .../models/roberta/modeling_roberta.py | 5 +- .../modeling_roberta_prelayernorm.py | 5 +- .../models/roc_bert/modeling_roc_bert.py | 79 --- .../models/roformer/modeling_roformer.py | 76 --- .../rt_detr/image_processing_rt_detr.py | 19 +- .../models/sam/image_processing_sam.py | 344 +------------ src/transformers/models/sam/processing_sam.py | 21 +- src/transformers/models/t5/modeling_t5.py | 121 ----- .../models/tapas/modeling_tapas.py | 140 ----- .../models/wav2vec2/tokenization_wav2vec2.py | 19 +- .../tokenization_wav2vec2_phoneme.py | 18 +- src/transformers/models/xlm/modeling_xlm.py | 1 - .../xlm_roberta/modeling_xlm_roberta.py | 5 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 - .../models/xlnet/modeling_xlnet.py | 154 ------ src/transformers/models/xmod/modeling_xmod.py | 5 +- .../models/yolos/image_processing_yolos.py | 19 +- src/transformers/pipelines/__init__.py | 275 +++------- src/transformers/pipelines/base.py | 338 +++--------- src/transformers/pipelines/fill_mask.py | 43 +- .../pipelines/question_answering.py | 33 +- .../pipelines/table_question_answering.py | 179 ++----- .../pipelines/text2text_generation.py | 38 +- src/transformers/pipelines/text_generation.py | 57 +-- .../pipelines/token_classification.py | 36 +- src/transformers/testing_utils.py | 44 +- src/transformers/tokenization_utils_base.py | 73 +-- src/transformers/trainer_utils.py | 17 +- src/transformers/utils/__init__.py | 4 - src/transformers/utils/doc.py | 458 +---------------- src/transformers/utils/generic.py | 157 +----- .../big_bird/test_tokenization_big_bird.py | 16 - tests/models/lxmert/test_modeling_lxmert.py | 30 +- utils/not_doctested.txt | 135 +---- 101 files changed, 395 insertions(+), 7289 deletions(-) delete mode 100644 src/transformers/commands/train.py delete mode 100644 src/transformers/keras_callbacks.py diff --git a/conftest.py b/conftest.py index 67064fbd5d3d..462a4b56de3d 100644 --- a/conftest.py +++ b/conftest.py @@ -67,8 +67,6 @@ "test_mismatched_shapes_have_properly_initialized_weights", "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist", "test_model_is_small", - "test_tf_from_pt_safetensors", - "test_flax_from_pt_safetensors", "ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device "ModelTester::test_pipeline_", "/repo_utils/", diff --git a/examples/legacy/multiple_choice/utils_multiple_choice.py b/examples/legacy/multiple_choice/utils_multiple_choice.py index cc07ffb2ef27..b62dabf76c56 100644 --- a/examples/legacy/multiple_choice/utils_multiple_choice.py +++ b/examples/legacy/multiple_choice/utils_multiple_choice.py @@ -26,7 +26,7 @@ import tqdm from filelock import FileLock -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available +from transformers import PreTrainedTokenizer, is_torch_available logger = logging.getLogger(__name__) @@ -139,94 +139,6 @@ def __getitem__(self, i) -> InputFeatures: return self.features[i] -if is_tf_available(): - import tensorflow as tf - - class TFMultipleChoiceDataset: - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: list[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - task: str, - max_seq_length: Optional[int] = 128, - overwrite_cache=False, - mode: Split = Split.train, - ): - processor = processors[task]() - - logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() - if mode == Split.dev: - examples = processor.get_dev_examples(data_dir) - elif mode == Split.test: - examples = processor.get_test_examples(data_dir) - else: - examples = processor.get_train_examples(data_dir) - logger.info("Training examples: %s", len(examples)) - - self.features = convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - ) - - def gen(): - for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - - yield ( - { - "example_id": 0, - "input_ids": ex.input_ids, - "attention_mask": ex.attention_mask, - "token_type_ids": ex.token_type_ids, - }, - ex.label, - ) - - self.dataset = tf.data.Dataset.from_generator( - gen, - ( - { - "example_id": tf.int32, - "input_ids": tf.int32, - "attention_mask": tf.int32, - "token_type_ids": tf.int32, - }, - tf.int64, - ), - ( - { - "example_id": tf.TensorShape([]), - "input_ids": tf.TensorShape([None, None]), - "attention_mask": tf.TensorShape([None, None]), - "token_type_ids": tf.TensorShape([None, None]), - }, - tf.TensorShape([]), - ), - ) - - def get_dataset(self): - self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features))) - - return self.dataset - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - class DataProcessor: """Base class for data converters for multiple choice data sets.""" diff --git a/examples/legacy/token-classification/utils_ner.py b/examples/legacy/token-classification/utils_ner.py index 0c1725b59b4e..809c95c26fd0 100644 --- a/examples/legacy/token-classification/utils_ner.py +++ b/examples/legacy/token-classification/utils_ner.py @@ -22,7 +22,7 @@ from filelock import FileLock -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available +from transformers import PreTrainedTokenizer, is_torch_available logger = logging.getLogger(__name__) @@ -271,100 +271,3 @@ def __len__(self): def __getitem__(self, i) -> InputFeatures: return self.features[i] - - -if is_tf_available(): - import tensorflow as tf - - class TFTokenClassificationDataset: - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: list[InputFeatures] - pad_token_label_id: int = -100 - # Use cross entropy ignore_index as padding label id so that only - # real label ids contribute to the loss later. - - def __init__( - self, - token_classification_task: TokenClassificationTask, - data_dir: str, - tokenizer: PreTrainedTokenizer, - labels: list[str], - model_type: str, - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - examples = token_classification_task.read_examples_from_file(data_dir, mode) - # TODO clean up all this to leverage built-in features of tokenizers - self.features = token_classification_task.convert_examples_to_features( - examples, - labels, - max_seq_length, - tokenizer, - cls_token_at_end=bool(model_type in ["xlnet"]), - # xlnet has a cls token at the end - cls_token=tokenizer.cls_token, - cls_token_segment_id=2 if model_type in ["xlnet"] else 0, - sep_token=tokenizer.sep_token, - sep_token_extra=False, - # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 - pad_on_left=bool(tokenizer.padding_side == "left"), - pad_token=tokenizer.pad_token_id, - pad_token_segment_id=tokenizer.pad_token_type_id, - pad_token_label_id=self.pad_token_label_id, - ) - - def gen(): - for ex in self.features: - if ex.token_type_ids is None: - yield ( - {"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, - ex.label_ids, - ) - else: - yield ( - { - "input_ids": ex.input_ids, - "attention_mask": ex.attention_mask, - "token_type_ids": ex.token_type_ids, - }, - ex.label_ids, - ) - - if "token_type_ids" not in tokenizer.model_input_names: - self.dataset = tf.data.Dataset.from_generator( - gen, - ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64), - ( - {"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, - tf.TensorShape([None]), - ), - ) - else: - self.dataset = tf.data.Dataset.from_generator( - gen, - ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64), - ( - { - "input_ids": tf.TensorShape([None]), - "attention_mask": tf.TensorShape([None]), - "token_type_ids": tf.TensorShape([None]), - }, - tf.TensorShape([None]), - ), - ) - - def get_dataset(self): - self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features))) - - return self.dataset - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2cf1d5970b54..d52faa02c86c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -46,7 +46,6 @@ from .utils import is_sentencepiece_available as is_sentencepiece_available from .utils import is_speech_available as is_speech_available from .utils import is_tensorflow_text_available as is_tensorflow_text_available -from .utils import is_tf_available as is_tf_available from .utils import is_timm_available as is_timm_available from .utils import is_tokenizers_available as is_tokenizers_available from .utils import is_torch_available as is_torch_available @@ -66,7 +65,6 @@ "configuration_utils": ["PretrainedConfig"], "convert_graph_to_onnx": [], "convert_slow_tokenizers_checkpoints_to_fast": [], - "convert_tf_hub_seq_to_seq_bert_to_pytorch": [], "data": [ "DataProcessor", "InputExample", @@ -137,16 +135,6 @@ ], "loss": [], "modelcard": ["ModelCard"], - # Losses - "modeling_tf_pytorch_utils": [ - "convert_tf_weight_name_to_pt_weight_name", - "load_pytorch_checkpoint_in_tf2_model", - "load_pytorch_model_in_tf2_model", - "load_pytorch_weights_in_tf2_model", - "load_tf2_checkpoint_in_pytorch_model", - "load_tf2_model_in_pytorch_model", - "load_tf2_weights_in_pytorch_model", - ], # Models "onnx": [], "pipelines": [ @@ -218,7 +206,6 @@ ], "training_args": ["TrainingArguments"], "training_args_seq2seq": ["Seq2SeqTrainingArguments"], - "training_args_tf": ["TFTrainingArguments"], "utils": [ "CONFIG_NAME", "MODEL_CARD_NAME", @@ -252,7 +239,6 @@ "is_sklearn_available", "is_speech_available", "is_tensorflow_text_available", - "is_tf_available", "is_timm_available", "is_tokenizers_available", "is_torch_available", @@ -501,84 +487,6 @@ _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] -# TensorFlow-backed objects -try: - if not is_tf_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_tf_objects - - _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] -else: - _import_structure["activations_tf"] = [] - _import_structure["generation"].extend( - [ - "TFForcedBOSTokenLogitsProcessor", - "TFForcedEOSTokenLogitsProcessor", - "TFForceTokensLogitsProcessor", - "TFGenerationMixin", - "TFLogitsProcessor", - "TFLogitsProcessorList", - "TFLogitsWarper", - "TFMinLengthLogitsProcessor", - "TFNoBadWordsLogitsProcessor", - "TFNoRepeatNGramLogitsProcessor", - "TFRepetitionPenaltyLogitsProcessor", - "TFSuppressTokensAtBeginLogitsProcessor", - "TFSuppressTokensLogitsProcessor", - "TFTemperatureLogitsWarper", - "TFTopKLogitsWarper", - "TFTopPLogitsWarper", - ] - ) - _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"] - _import_structure["modeling_tf_outputs"] = [] - _import_structure["modeling_tf_utils"] = [ - "TFPreTrainedModel", - "TFSequenceSummary", - "TFSharedEmbeddings", - "shape_list", - ] - _import_structure["optimization_tf"] = [ - "AdamWeightDecay", - "GradientAccumulator", - "WarmUp", - "create_optimizer", - ] - _import_structure["tf_utils"] = [] - - -# FLAX-backed objects -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects - - _import_structure["utils.dummy_flax_objects"] = [ - name for name in dir(dummy_flax_objects) if not name.startswith("_") - ] -else: - _import_structure["generation"].extend( - [ - "FlaxForcedBOSTokenLogitsProcessor", - "FlaxForcedEOSTokenLogitsProcessor", - "FlaxForceTokensLogitsProcessor", - "FlaxGenerationMixin", - "FlaxLogitsProcessor", - "FlaxLogitsProcessorList", - "FlaxLogitsWarper", - "FlaxMinLengthLogitsProcessor", - "FlaxTemperatureLogitsWarper", - "FlaxSuppressTokensAtBeginLogitsProcessor", - "FlaxSuppressTokensLogitsProcessor", - "FlaxTopKLogitsWarper", - "FlaxTopPLogitsWarper", - "FlaxWhisperTimeStampLogitsProcessor", - ] - ) - _import_structure["modeling_flax_outputs"] = [] - _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] # Direct imports for type-checking if TYPE_CHECKING: @@ -716,22 +624,6 @@ from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper from .generation import TextIteratorStreamer as TextIteratorStreamer from .generation import TextStreamer as TextStreamer - from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor - from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor - from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor - from .generation import TFGenerationMixin as TFGenerationMixin - from .generation import TFLogitsProcessor as TFLogitsProcessor - from .generation import TFLogitsProcessorList as TFLogitsProcessorList - from .generation import TFLogitsWarper as TFLogitsWarper - from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor - from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor - from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor - from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor - from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor - from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor - from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper - from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper - from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper from .generation import TopKLogitsWarper as TopKLogitsWarper from .generation import TopPLogitsWarper as TopPLogitsWarper from .generation import TypicalLogitsWarper as TypicalLogitsWarper @@ -775,20 +667,6 @@ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update - # TF 2.0 <=> PyTorch conversion utilities - from .modeling_tf_pytorch_utils import ( - convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name, - ) - from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model - from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model - from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model - from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model - from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model - from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model - from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel - from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary - from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings - from .modeling_tf_utils import shape_list as shape_list from .modeling_utils import AttentionInterface as AttentionInterface from .modeling_utils import PreTrainedModel as PreTrainedModel from .models import * @@ -815,12 +693,6 @@ from .optimization import get_scheduler as get_scheduler from .optimization import get_wsd_schedule as get_wsd_schedule - # Optimization - from .optimization_tf import AdamWeightDecay as AdamWeightDecay - from .optimization_tf import GradientAccumulator as GradientAccumulator - from .optimization_tf import WarmUp as WarmUp - from .optimization_tf import create_optimizer as create_optimizer - # Pipelines from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline @@ -894,7 +766,6 @@ from .trainer_utils import set_seed as set_seed from .training_args import TrainingArguments as TrainingArguments from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments - from .training_args_tf import TFTrainingArguments as TFTrainingArguments # Files and general utilities from .utils import CONFIG_NAME as CONFIG_NAME @@ -968,9 +839,7 @@ ) -if not is_tf_available() and not is_torch_available() and not is_flax_available(): +if not is_torch_available(): logger.warning_advice( - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. " - "Models won't be available and only tokenizers, configuration " - "and file/data utilities can be used." + "PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used." ) diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index 983a858cd952..9ef31c71a0d1 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -26,9 +26,7 @@ from ..integrations.deepspeed import is_deepspeed_available from ..utils import ( is_accelerate_available, - is_flax_available, is_safetensors_available, - is_tf_available, is_torch_available, is_torch_hpu_available, is_torch_npu_available, @@ -109,19 +107,6 @@ def run(self): elif pt_hpu_available: pt_accelerator = "HPU" - tf_version = "not installed" - tf_cuda_available = "NA" - if is_tf_available(): - import tensorflow as tf - - tf_version = tf.__version__ - try: - # deprecated in v2.1 - tf_cuda_available = tf.test.is_gpu_available() - except AttributeError: - # returns list of devices, convert to bool - tf_cuda_available = bool(tf.config.list_physical_devices("GPU")) - deepspeed_version = "not installed" if is_deepspeed_available(): # Redirect command line output to silence deepspeed import output. @@ -129,20 +114,6 @@ def run(self): import deepspeed deepspeed_version = deepspeed.__version__ - flax_version = "not installed" - jax_version = "not installed" - jaxlib_version = "not installed" - jax_backend = "NA" - if is_flax_available(): - import flax - import jax - import jaxlib - - flax_version = flax.__version__ - jax_version = jax.__version__ - jaxlib_version = jaxlib.__version__ - jax_backend = jax.lib.xla_bridge.get_backend().platform - info = { "`transformers` version": version, "Platform": platform.platform(), @@ -153,10 +124,6 @@ def run(self): "Accelerate config": f"{accelerate_config_str}", "DeepSpeed version": f"{deepspeed_version}", "PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})", - "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})", - "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})", - "Jax version": f"{jax_version}", - "JaxLib version": f"{jaxlib_version}", "Using distributed or parallel set-up in script?": "", } if is_torch_available(): diff --git a/src/transformers/commands/train.py b/src/transformers/commands/train.py deleted file mode 100644 index 06e95443df24..000000000000 --- a/src/transformers/commands/train.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from argparse import ArgumentParser, Namespace - -from ..data import SingleSentenceClassificationProcessor as Processor -from ..pipelines import TextClassificationPipeline -from ..utils import is_tf_available, is_torch_available, logging -from . import BaseTransformersCLICommand - - -if not is_tf_available() and not is_torch_available(): - raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training") - -# TF training parameters -USE_XLA = False -USE_AMP = False - - -def train_command_factory(args: Namespace): - """ - Factory function used to instantiate training command from provided command line arguments. - - Returns: TrainCommand - """ - return TrainCommand(args) - - -class TrainCommand(BaseTransformersCLICommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - """ - Register this command to argparse so it's available for the transformer-cli - - Args: - parser: Root parser to register command-specific arguments - """ - train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.") - - train_parser.add_argument( - "--train_data", - type=str, - required=True, - help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.", - ) - train_parser.add_argument( - "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels." - ) - train_parser.add_argument( - "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts." - ) - train_parser.add_argument( - "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids." - ) - train_parser.add_argument( - "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)." - ) - - train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.") - train_parser.add_argument( - "--validation_split", - type=float, - default=0.1, - help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.", - ) - - train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.") - - train_parser.add_argument( - "--task", type=str, default="text_classification", help="Task to train the model on." - ) - train_parser.add_argument( - "--model", type=str, default="google-bert/bert-base-uncased", help="Model's name or path to stored model." - ) - train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.") - train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.") - train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.") - train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.") - train_parser.set_defaults(func=train_command_factory) - - def __init__(self, args: Namespace): - self.logger = logging.get_logger("transformers/training") - - self.framework = "tf" if is_tf_available() else "torch" - - os.makedirs(args.output, exist_ok=True) - self.output = args.output - - self.column_label = args.column_label - self.column_text = args.column_text - self.column_id = args.column_id - - self.logger.info(f"Loading {args.task} pipeline for {args.model}") - if args.task == "text_classification": - self.pipeline = TextClassificationPipeline.from_pretrained(args.model) - elif args.task == "token_classification": - raise NotImplementedError - elif args.task == "question_answering": - raise NotImplementedError - - self.logger.info(f"Loading dataset from {args.train_data}") - self.train_dataset = Processor.create_from_csv( - args.train_data, - column_label=args.column_label, - column_text=args.column_text, - column_id=args.column_id, - skip_first_row=args.skip_first_row, - ) - self.valid_dataset = None - if args.validation_data: - self.logger.info(f"Loading validation dataset from {args.validation_data}") - self.valid_dataset = Processor.create_from_csv( - args.validation_data, - column_label=args.column_label, - column_text=args.column_text, - column_id=args.column_id, - skip_first_row=args.skip_first_row, - ) - - self.validation_split = args.validation_split - self.train_batch_size = args.train_batch_size - self.valid_batch_size = args.valid_batch_size - self.learning_rate = args.learning_rate - self.adam_epsilon = args.adam_epsilon - - def run(self): - if self.framework == "tf": - return self.run_tf() - return self.run_torch() - - def run_torch(self): - raise NotImplementedError - - def run_tf(self): - self.pipeline.fit( - self.train_dataset, - validation_data=self.valid_dataset, - validation_split=self.validation_split, - learning_rate=self.learning_rate, - adam_epsilon=self.adam_epsilon, - train_batch_size=self.train_batch_size, - valid_batch_size=self.valid_batch_size, - ) - - # Save trained pipeline - self.pipeline.save_pretrained(self.output) diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py index 922ece8c0f45..1d7109d333d3 100644 --- a/src/transformers/convert_graph_to_onnx.py +++ b/src/transformers/convert_graph_to_onnx.py @@ -22,7 +22,7 @@ from transformers.pipelines import Pipeline, pipeline from transformers.tokenization_utils import BatchEncoding -from transformers.utils import ModelOutput, is_tf_available, is_torch_available +from transformers.utils import ModelOutput, is_torch_available # This is the minimal required version to @@ -67,7 +67,8 @@ def __init__(self): self.add_argument( "--framework", type=str, - choices=["pt", "tf"], + choices=["pt"], + default="pt", help="Framework for loading the model", ) self.add_argument("--opset", type=int, default=11, help="ONNX opset to use") @@ -165,7 +166,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> tuple[list[str], list[str], d Args: nlp: The pipeline object holding the model to be exported - framework: The framework identifier to dispatch to the correct inference scheme (pt/tf) + framework: Not used anymore, only kept for BC Returns: @@ -194,9 +195,9 @@ def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int): print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}") return axes - tokens = nlp.tokenizer("This is a sample output", return_tensors=framework) + tokens = nlp.tokenizer("This is a sample output", return_tensors="pt") seq_len = tokens.input_ids.shape[-1] - outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens) + outputs = nlp.model(**tokens) if isinstance(outputs, ModelOutput): outputs = outputs.to_tuple() if not isinstance(outputs, (list, tuple)): @@ -231,7 +232,7 @@ def load_graph_from_args( Args: pipeline_name: The kind of pipeline to use (ner, question-answering, etc.) - framework: The actual model to convert the pipeline from ("pt" or "tf") + framework: Not used anymore, only kept for BC model: The model name which will be loaded by the pipeline tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value @@ -242,16 +243,14 @@ def load_graph_from_args( if tokenizer is None: tokenizer = model - # Check the wanted framework is available - if framework == "pt" and not is_torch_available(): + # Check pytorch is available + if not is_torch_available(): raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") - if framework == "tf" and not is_tf_available(): - raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.") print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})") # Allocate tokenizer and model - return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs) + return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework="pt", model_kwargs=models_kwargs) def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool): @@ -291,46 +290,6 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format ) -def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): - """ - Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR) - - Args: - nlp: The pipeline to be exported - opset: The actual version of the ONNX operator set to use - output: Path where will be stored the generated ONNX model - - Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow - - """ - if not is_tf_available(): - raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.") - - print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\") - - try: - import tensorflow as tf - import tf2onnx - from tf2onnx import __version__ as t2ov - - print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}") - - # Build - input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf") - - # Forward - nlp.model.predict(tokens.data) - input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()] - model_proto, _ = tf2onnx.convert.from_keras( - nlp.model, input_signature, opset=opset, output_path=output.as_posix() - ) - - except ImportError as e: - raise Exception( - f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}" - ) - - def convert( framework: str, model: str, @@ -345,7 +304,7 @@ def convert( Convert the pipeline object to the ONNX Intermediate Representation (IR) format Args: - framework: The framework the pipeline is backed by ("pt" or "tf") + framework: Not used anymore, only kept for BC model: The name of the model to load for the pipeline output: The path where the ONNX graph will be stored opset: The actual version of the ONNX operator set to use @@ -366,7 +325,7 @@ def convert( print(f"ONNX opset version set to: {opset}") # Load the pipeline - nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs) + nlp = load_graph_from_args(pipeline_name, "pt", model, tokenizer, **model_kwargs) if not output.parent.exists(): print(f"Creating folder {output.parent}") @@ -375,10 +334,7 @@ def convert( raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion") # Export the graph - if framework == "pt": - convert_pytorch(nlp, opset, output, use_external_format) - else: - convert_tensorflow(nlp, opset, output) + convert_pytorch(nlp, opset, output, use_external_format) def optimize(onnx_model_path: Path) -> Path: @@ -518,15 +474,6 @@ def verify(path: Path): # Ensure requirements for quantization on onnxruntime is met check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION) - # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch - if args.framework == "tf": - print( - "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n" - "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n" - "\t For more information, please refer to the onnxruntime documentation:\n" - "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n" - ) - print("\n====== Optimizing ONNX model ======") # Quantization works best when using the optimized version of the model diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 10ee10e01950..d3bf1c6cfdcb 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -31,7 +31,7 @@ """ A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary -of PyTorch/TensorFlow tensors or NumPy arrays. +of PyTorch tensors or NumPy arrays. """ DataCollator = NewType("DataCollator", Callable[[list[InputDataClass]], dict[str, Any]]) @@ -40,8 +40,6 @@ class DataCollatorMixin: def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors - if return_tensors == "tf": - return self.tf_call(features) elif return_tensors == "pt": return self.torch_call(features) elif return_tensors == "np": @@ -91,8 +89,6 @@ def default_data_collator(features: list[InputDataClass], return_tensors="pt") - if return_tensors == "pt": return torch_default_data_collator(features) - elif return_tensors == "tf": - return tf_default_data_collator(features) elif return_tensors == "np": return numpy_default_data_collator(features) @@ -114,7 +110,7 @@ class DefaultDataCollator(DataCollatorMixin): Args: return_tensors (`str`, *optional*, defaults to `"pt"`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". """ return_tensors: str = "pt" @@ -161,47 +157,6 @@ def torch_default_data_collator(features: list[InputDataClass]) -> dict[str, Any return batch -def tf_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]: - import tensorflow as tf - - if not isinstance(features[0], Mapping): - features = [vars(f) for f in features] - first = features[0] - batch = {} - - # Special handling for labels. - # Ensure that tensor is created with the correct type - # (it should be automatically the case, but let's make sure of it.) - if "label" in first and first["label"] is not None: - label_col_name = "label" - elif "label_ids" in first and first["label_ids"] is not None: - label_col_name = "label_ids" - elif "labels" in first and first["labels"] is not None: - label_col_name = "labels" - else: - label_col_name = None - if label_col_name is not None: - if isinstance(first[label_col_name], tf.Tensor): - dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32 - elif isinstance(first[label_col_name], (np.ndarray, np.generic)): - dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32 - elif isinstance(first[label_col_name], (tuple, list)): - dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32 - else: - dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32 - batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype) - # Handling of all other possible keys. - # Again, we will use the first element to figure out which key/values are not None for this model. - for k, v in first.items(): - if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str): - if isinstance(v, (tf.Tensor, np.ndarray)): - batch[k] = tf.stack([f[k] for f in features]) - else: - batch[k] = tf.convert_to_tensor([f[k] for f in features]) - - return batch - - def numpy_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]: if not isinstance(features[0], Mapping): features = [vars(f) for f in features] @@ -259,7 +214,7 @@ class DataCollatorWithPadding: This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.0 (Volta). return_tensors (`str`, *optional*, defaults to `"pt"`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". """ tokenizer: PreTrainedTokenizerBase @@ -313,7 +268,7 @@ class DataCollatorForTokenClassification(DataCollatorMixin): label_pad_token_id (`int`, *optional*, defaults to -100): The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). return_tensors (`str`, *optional*, defaults to `"pt"`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". """ tokenizer: PreTrainedTokenizerBase @@ -363,38 +318,6 @@ def to_list(tensor_or_iterable): batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64) return batch - def tf_call(self, features): - import tensorflow as tf - - label_name = "label" if "label" in features[0] else "labels" - labels = [feature[label_name] for feature in features] if label_name in features[0] else None - batch = pad_without_fast_tokenizer_warning( - self.tokenizer, - features, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - # Conversion to tensors will fail if we have labels as they are not of the same length yet. - return_tensors="tf" if labels is None else None, - ) - - if labels is None: - return batch - - sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1] - padding_side = self.tokenizer.padding_side - if padding_side == "right": - batch["labels"] = [ - list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels - ] - else: - batch["labels"] = [ - [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels - ] - - batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()} - return batch - def numpy_call(self, features): label_name = "label" if "label" in features[0] else "labels" labels = [feature[label_name] for feature in features] if label_name in features[0] else None @@ -463,44 +386,6 @@ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] return result -def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): - import tensorflow as tf - - """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" - # Tensorize if necessary. - if isinstance(examples[0], (list, tuple)): - examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples] - - # Check if padding is necessary. - length_of_first = len(examples[0]) - are_tensors_same_length = all(len(x) == length_of_first for x in examples) - if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): - return tf.stack(examples, axis=0) - - # If yes, check if we have a `pad_token`. - if tokenizer.pad_token is None: - raise ValueError( - "You are attempting to pad samples but the tokenizer you are using" - f" ({tokenizer.__class__.__name__}) does not have a pad token." - ) - - # Creating the full tensor and filling it with our data. - max_length = max(len(x) for x in examples) - if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) - result = [] - rank = tf.rank(examples[0]) - paddings = np.zeros((rank, 2), dtype=np.int32) - for example in examples: - if tokenizer.padding_side == "right": - paddings[0, 1] = max_length - len(example) - else: - paddings[0, 0] = max_length - len(example) - result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id)) - return tf.stack(result, axis=0) - - def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" # Tensorize if necessary. @@ -560,7 +445,7 @@ class DataCollatorForMultipleChoice(DataCollatorMixin): This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). return_tensors (`str`, *optional*, defaults to `"pt"`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". """ tokenizer: PreTrainedTokenizerBase @@ -599,30 +484,6 @@ def torch_call(self, examples: list[dict[str, Any]]): # Refactored implementati batch["labels"] = torch.tensor(labels, dtype=torch.int64) return batch - def tf_call(self, features): # Implementation taken from the docs. - import tensorflow as tf - - label_name = "label" if "label" in features[0] else "labels" - labels = [feature.pop(label_name) for feature in features] - batch_size = len(features) - num_choices = len(features[0]["input_ids"]) - flattened_features = [ - [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features - ] - flattened_features = sum(flattened_features, []) # Sometimes written as list(chain(*flattened_features)) - - batch = self.tokenizer.pad( - flattened_features, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="tf", - ) - - batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()} - batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64) - return batch - @dataclass class DataCollatorForSeq2Seq: @@ -656,7 +517,7 @@ class DataCollatorForSeq2Seq: label_pad_token_id (`int`, *optional*, defaults to -100): The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). return_tensors (`str`, *optional*, defaults to `"pt"`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". """ tokenizer: PreTrainedTokenizerBase @@ -739,10 +600,6 @@ def __call__(self, features, return_tensors=None): import torch batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64) - elif return_tensors == "tf": - import tensorflow as tf - - batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64) else: batch["labels"] = np.array(batch["labels"], dtype=np.int64) else: @@ -787,7 +644,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): pad_to_multiple_of (`int`, *optional*): If set, will pad the sequence to a multiple of the provided value. return_tensors (`str`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". + The type of Tensor to return. Allowable values are "np", or "pt". seed (`int`, *optional*): The seed to use for the random number generator for masking. If not provided, the global RNG will be used. @@ -852,11 +709,6 @@ def __post_init__(self): self.mask_replace_prob = float(self.mask_replace_prob) self.random_replace_prob = float(self.random_replace_prob) - if self.tf_experimental_compile: - import tensorflow as tf - - self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True) - self.generator = None def get_generator(self, seed): @@ -864,10 +716,6 @@ def get_generator(self, seed): import torch return torch.Generator().manual_seed(seed) - elif self.return_tensors == "tf": - import tensorflow as tf - - return tf.random.Generator.from_seed(seed) else: import numpy as np @@ -897,111 +745,6 @@ def create_rng(self): self.generator = self.get_generator(self.seed + worker_info.id) - @staticmethod - def tf_bernoulli(shape, probability, generator=None): - import tensorflow as tf - - prob_matrix = tf.fill(shape, probability) - # if generator exists, use it to generate the random numbers - # otherwise, use the global RNG - if generator: - return tf.cast(prob_matrix - generator.uniform(shape, 0, 1) >= 0, tf.bool) - else: - return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool) - - def tf_mask_tokens( - self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None - ) -> tuple[Any, Any]: - """ - Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. - """ - import tensorflow as tf - - mask_token_id = tf.cast(mask_token_id, inputs.dtype) - - input_shape = tf.shape(inputs) - # 1 for a special token, 0 for a normal token in the special tokens mask - # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) - masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability, self.generator) & ~special_tokens_mask - # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens - labels = tf.where(masked_indices, inputs, -100) - - # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices - - inputs = tf.where(indices_replaced, mask_token_id, inputs) - - if self.mask_replace_prob == 1 or self.random_replace_prob == 0: - return inputs, labels - - remaining_prob = 1 - self.mask_replace_prob - # scaling the random_replace_prob to the remaining probability for example if - # mask_replace_prob = 0.8 and random_replace_prob = 0.1, - # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5 - random_replace_prob_scaled = self.random_replace_prob / remaining_prob - # random_replace_prob% of the time, we replace masked input tokens with random word - indices_random = ( - self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator) - & masked_indices - & ~indices_replaced - ) - - if self.generator: - random_words = self.generator.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype) - else: - random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype) - - inputs = tf.where(indices_random, random_words, inputs) - - # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged - return inputs, labels - - def tf_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - import tensorflow as tf - - if self.seed and self.generator is None: - # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. - # If no seed supplied, we will use the global RNG - self.create_rng() - - # Handle dict or lists with proper padding and conversion to tensor. - if isinstance(examples[0], Mapping): - batch = pad_without_fast_tokenizer_warning( - self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of - ) - else: - batch = { - "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) - } - - # If special token mask has been preprocessed, pop it from the dict. - special_tokens_mask = batch.pop("special_tokens_mask", None) - if self.mlm: - if special_tokens_mask is None: - special_tokens_mask = [ - self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) - for val in batch["input_ids"].numpy().tolist() - ] - # Cannot directly create as bool - special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool) - else: - special_tokens_mask = tf.cast(special_tokens_mask, tf.bool) - batch["input_ids"], batch["labels"] = self.tf_mask_tokens( - tf.cast(batch["input_ids"], tf.int64), - special_tokens_mask=special_tokens_mask, - mask_token_id=self.tokenizer.mask_token_id, - vocab_size=len(self.tokenizer), - ) - else: - labels = batch["input_ids"] - if self.tokenizer.pad_token_id is not None: - # Replace self.tokenizer.pad_token_id with -100 - labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels) - else: - labels = tf.identity(labels) # Makes a copy, just in case - batch["labels"] = labels - return batch - def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: # Handle dict or lists with proper padding and conversion to tensor. @@ -1226,41 +969,6 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d inputs, labels = self.torch_mask_tokens(batch_input, batch_mask) return {"input_ids": inputs, "labels": labels} - def tf_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - import tensorflow as tf - - if self.seed and self.generator is None: - # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. - # If no seed supplied, we will use the global RNG - self.create_rng() - - if isinstance(examples[0], Mapping): - input_ids = [e["input_ids"] for e in examples] - else: - input_ids = examples - examples = [{"input_ids": e} for e in examples] - - batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) - - mask_labels = [] - for e in examples: - ref_tokens = [] - for id in tolist(e["input_ids"]): - token = self.tokenizer._convert_id_to_token(id) - ref_tokens.append(token) - - # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢] - if "chinese_ref" in e: - ref_pos = tolist(e["chinese_ref"]) - len_seq = len(e["input_ids"]) - for i in range(len_seq): - if i in ref_pos: - ref_tokens[i] = "##" + ref_tokens[i] - mask_labels.append(self._whole_word_mask(ref_tokens)) - batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) - inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask) - return {"input_ids": inputs, "labels": labels} - def numpy_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: if self.seed and self.generator is None: # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator. @@ -1307,13 +1015,6 @@ def _shuffle(self, cand_indexes): indices = torch.randperm(len(cand_indexes), generator=self.generator) return [cand_indexes[i] for i in indices] - elif self.return_tensors == "tf": - import tensorflow as tf - - seed = self.generator.make_seeds(2)[0] - indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist() - return [cand_indexes[i] for i in indices] - elif self.return_tensors == "np": self.generator.shuffle(cand_indexes) return cand_indexes @@ -1414,66 +1115,6 @@ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]: # The rest of the time ((1-random_replacement_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged return inputs, labels - def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]: - """ - Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set - 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref. - """ - import tensorflow as tf - - input_shape = tf.shape(inputs) - if self.tokenizer.mask_token is None: - raise ValueError( - "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the" - " --mlm flag if you want to use this tokenizer." - ) - labels = tf.identity(inputs) - # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) - - masked_indices = tf.cast(mask_labels, tf.bool) - - special_tokens_mask = [ - self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels - ] - masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool) - if self.tokenizer.pad_token is not None: - padding_mask = inputs == self.tokenizer.pad_token_id - masked_indices = masked_indices & ~padding_mask - - # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens - labels = tf.where(masked_indices, inputs, -100) - - # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) - indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices - - inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs) - - if self.mask_replace_prob == 1 or self.random_replace_prob == 0: - return inputs, labels - - remaining_prob = 1 - self.mask_replace_prob - # scaling the random_replace_prob to the remaining probability for example if - # mask_replace_prob = 0.8 and random_replace_prob = 0.1, - # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5 - random_replace_prob_scaled = self.random_replace_prob / remaining_prob - - # random_replace_prob% of the time, we replace masked input tokens with random word - indices_random = ( - self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator) - & masked_indices - & ~indices_replaced - ) - - if self.generator: - random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) - else: - random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) - - inputs = tf.where(indices_random, random_words, inputs) - - # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged - return inputs, labels - def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]: """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set @@ -1543,7 +1184,7 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]: def tolist(x): if isinstance(x, list): return x - elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import + elif hasattr(x, "numpy"): x = x.numpy() return x.tolist() @@ -1765,113 +1406,6 @@ def torch_mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any, Any]: return inputs.long(), perm_mask, target_mapping, labels.long() - def tf_mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any, Any]: - """ - The masked tokens to be predicted for a particular sequence are determined by the following algorithm: - - 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far). - 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked) - 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be - masked - 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - - span_length]` and mask tokens `start_index:start_index + span_length` - 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the - sequence to be processed), repeat from Step 1. - """ - import tensorflow as tf - - if self.tokenizer.mask_token is None: - raise ValueError( - "This tokenizer does not have a mask token which is necessary for permutation language modeling." - " Please add a mask token if you want to use this tokenizer." - ) - - if tf.shape(inputs)[1] % 2 != 0: - raise ValueError( - "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see" - " relevant comments in source code for details." - ) - - labels = tf.identity(inputs) - # Creating the mask and target_mapping tensors - masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool) - labels_shape = tf.shape(labels) - target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32) - - for i in range(len(labels)): - # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far). - cur_len = 0 - max_len = tf.shape(labels)[1] - - while cur_len < max_len: - # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked) - span_length = randint(1, self.max_span_length + 1) - # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked - context_length = int(span_length / self.plm_probability) - # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length` - start_index = cur_len + randint(0, context_length - span_length + 1) - masked_indices[i, start_index : start_index + span_length] = 1 - # Set `cur_len = cur_len + context_length` - cur_len += context_length - - # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether, - # the i-th predict corresponds to the i-th token. - target_mapping[i] = np.eye(labels_shape[1]) - masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool) - target_mapping = tf.convert_to_tensor(target_mapping) - special_tokens_mask = tf.convert_to_tensor( - [ - self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) - for val in labels.numpy().tolist() - ], - ) - special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool) - masked_indices = masked_indices & ~special_tokens_mask - if self.tokenizer.pad_token is not None: - padding_mask = labels == self.tokenizer.pad_token_id - masked_indices = masked_indices & ~padding_mask - - # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc. - non_func_mask = ~(padding_mask | special_tokens_mask) - - inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs) - labels = tf.where(masked_indices, labels, -100) # We only compute loss on masked tokens - - perm_mask = [] - - for i in range(len(labels)): - # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will - # determine which tokens a given token can attend to (encoded in `perm_mask`). - # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length - # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation, - # we assume that reused length is half of sequence length and permutation length is equal to reused length. - # This requires that the sequence length be even. - - # Create a linear factorisation order - # tf.range is the equivalent of torch.arange - perm_index = tf.range(labels_shape[1]) - # Split this into two halves, assuming that half the sequence is reused each time - perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2))) - # Permute the two halves such that they do not cross over - perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension - # Flatten this out into the desired permuted factorisation order - perm_index = tf.reshape(tf.transpose(perm_index), (-1,)) - # Set the permutation indices of non-masked (non-functional) tokens to the - # smallest index (-1) so that: - # (1) They can be seen by all other positions - # (2) They cannot see masked positions, so there won't be information leak - perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index) - # The logic for whether the i-th token can attend on the j-th token based on the factorisation order: - # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token - # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token - perm_mask.append( - (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1]))) - & masked_indices[i] - ) - perm_mask = tf.stack(perm_mask, axis=0) - - return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64) - def numpy_mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any, Any]: """ The masked tokens to be predicted for a particular sequence are determined by the following algorithm: diff --git a/src/transformers/data/processors/glue.py b/src/transformers/data/processors/glue.py index e005c9bcda13..abf03c917202 100644 --- a/src/transformers/data/processors/glue.py +++ b/src/transformers/data/processors/glue.py @@ -17,18 +17,14 @@ import os import warnings -from dataclasses import asdict from enum import Enum from typing import Optional, Union from ...tokenization_utils import PreTrainedTokenizer -from ...utils import is_tf_available, logging +from ...utils import logging from .utils import DataProcessor, InputExample, InputFeatures -if is_tf_available(): - import tensorflow as tf - logger = logging.get_logger(__name__) DEPRECATION_WARNING = ( @@ -39,7 +35,7 @@ def glue_convert_examples_to_features( - examples: Union[list[InputExample], "tf.data.Dataset"], + examples: list[InputExample], tokenizer: PreTrainedTokenizer, max_length: Optional[int] = None, task=None, @@ -50,7 +46,7 @@ def glue_convert_examples_to_features( Loads a data file into a list of `InputFeatures` Args: - examples: List of `InputExamples` or `tf.data.Dataset` containing the examples. + examples: List of `InputExamples` containing the examples. tokenizer: Instance of a tokenizer that will tokenize the examples max_length: Maximum example length. Defaults to the tokenizer's max_len task: GLUE task @@ -58,54 +54,15 @@ def glue_convert_examples_to_features( output_mode: String indicating the output mode. Either `regression` or `classification` Returns: - If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific - features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which - can be fed to the model. + Will return a list of task-specific `InputFeatures` which can be fed to the model. """ warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning) - if is_tf_available() and isinstance(examples, tf.data.Dataset): - if task is None: - raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.") - return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task) return _glue_convert_examples_to_features( examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode ) -if is_tf_available(): - - def _tf_glue_convert_examples_to_features( - examples: tf.data.Dataset, - tokenizer: PreTrainedTokenizer, - task=str, - max_length: Optional[int] = None, - ) -> tf.data.Dataset: - """ - Returns: - A `tf.data.Dataset` containing the task-specific features. - - """ - processor = glue_processors[task]() - examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples] - features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task) - label_type = tf.float32 if task == "sts-b" else tf.int64 - - def gen(): - for ex in features: - d = {k: v for k, v in asdict(ex).items() if v is not None} - label = d.pop("label") - yield (d, label) - - input_names = tokenizer.model_input_names - - return tf.data.Dataset.from_generator( - gen, - (dict.fromkeys(input_names, tf.int32), label_type), - ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])), - ) - - def _glue_convert_examples_to_features( examples: list[InputExample], tokenizer: PreTrainedTokenizer, diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 5f37eb018133..e8af1549a86e 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -24,7 +24,7 @@ from ...models.bert.tokenization_bert import whitespace_tokenize from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy -from ...utils import is_tf_available, is_torch_available, is_torch_hpu_available, logging +from ...utils import is_torch_available, is_torch_hpu_available, logging from .utils import DataProcessor @@ -36,8 +36,6 @@ import torch from torch.utils.data import TensorDataset -if is_tf_available(): - import tensorflow as tf logger = logging.get_logger(__name__) @@ -244,7 +242,6 @@ def squad_convert_example_to_features( cls_index = span["input_ids"].index(tokenizer.cls_token_id) # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) - # Original TF implementation also keep the classification token (set to 0) p_mask = np.ones_like(span["token_type_ids"]) if tokenizer.padding_side == "right": p_mask[len(truncated_query) + sequence_added_tokens :] = 0 @@ -338,8 +335,8 @@ def squad_convert_examples_to_features( max_query_length: The maximum length of the query. is_training: whether to create features for model evaluation or model training. padding_strategy: Default to "max_length". Which padding strategy to use - return_dataset: Default False. Either 'pt' or 'tf'. - if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset + return_dataset: Default False. Can also be 'pt'. + if 'pt': returns a torch.data.TensorDataset. threads: multiple processing threads. @@ -430,110 +427,6 @@ def squad_convert_examples_to_features( ) return features, dataset - elif return_dataset == "tf": - if not is_tf_available(): - raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.") - - def gen(): - for i, ex in enumerate(features): - if ex.token_type_ids is None: - yield ( - { - "input_ids": ex.input_ids, - "attention_mask": ex.attention_mask, - "feature_index": i, - "qas_id": ex.qas_id, - }, - { - "start_positions": ex.start_position, - "end_positions": ex.end_position, - "cls_index": ex.cls_index, - "p_mask": ex.p_mask, - "is_impossible": ex.is_impossible, - }, - ) - else: - yield ( - { - "input_ids": ex.input_ids, - "attention_mask": ex.attention_mask, - "token_type_ids": ex.token_type_ids, - "feature_index": i, - "qas_id": ex.qas_id, - }, - { - "start_positions": ex.start_position, - "end_positions": ex.end_position, - "cls_index": ex.cls_index, - "p_mask": ex.p_mask, - "is_impossible": ex.is_impossible, - }, - ) - - # Why have we split the batch into a tuple? PyTorch just has a list of tensors. - if "token_type_ids" in tokenizer.model_input_names: - train_types = ( - { - "input_ids": tf.int32, - "attention_mask": tf.int32, - "token_type_ids": tf.int32, - "feature_index": tf.int64, - "qas_id": tf.string, - }, - { - "start_positions": tf.int64, - "end_positions": tf.int64, - "cls_index": tf.int64, - "p_mask": tf.int32, - "is_impossible": tf.int32, - }, - ) - - train_shapes = ( - { - "input_ids": tf.TensorShape([None]), - "attention_mask": tf.TensorShape([None]), - "token_type_ids": tf.TensorShape([None]), - "feature_index": tf.TensorShape([]), - "qas_id": tf.TensorShape([]), - }, - { - "start_positions": tf.TensorShape([]), - "end_positions": tf.TensorShape([]), - "cls_index": tf.TensorShape([]), - "p_mask": tf.TensorShape([None]), - "is_impossible": tf.TensorShape([]), - }, - ) - else: - train_types = ( - {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string}, - { - "start_positions": tf.int64, - "end_positions": tf.int64, - "cls_index": tf.int64, - "p_mask": tf.int32, - "is_impossible": tf.int32, - }, - ) - - train_shapes = ( - { - "input_ids": tf.TensorShape([None]), - "attention_mask": tf.TensorShape([None]), - "feature_index": tf.TensorShape([]), - "qas_id": tf.TensorShape([]), - }, - { - "start_positions": tf.TensorShape([]), - "end_positions": tf.TensorShape([]), - "cls_index": tf.TensorShape([]), - "p_mask": tf.TensorShape([None]), - "is_impossible": tf.TensorShape([]), - }, - ) - - return tf.data.Dataset.from_generator(gen, train_types, train_shapes) else: return features diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py index 462156ebac38..c3db333ce594 100644 --- a/src/transformers/data/processors/utils.py +++ b/src/transformers/data/processors/utils.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from typing import Optional, Union -from ...utils import is_tf_available, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -251,9 +251,7 @@ def get_features( values) Returns: - If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the - task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific - `InputFeatures` which can be fed to the model. + Will return a list of task-specific `InputFeatures` which can be fed to the model. """ if max_length is None: @@ -315,21 +313,6 @@ def get_features( if return_tensors is None: return features - elif return_tensors == "tf": - if not is_tf_available(): - raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported") - import tensorflow as tf - - def gen(): - for ex in features: - yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label) - - dataset = tf.data.Dataset.from_generator( - gen, - ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64), - ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])), - ) - return dataset elif return_tensors == "pt": if not is_torch_available(): raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported") @@ -346,4 +329,4 @@ def gen(): dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels) return dataset else: - raise ValueError("return_tensors should be one of 'tf' or 'pt'") + raise ValueError("return_tensors should be `'pt'` or `None`") diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index a9ff39b0cc19..20d23d78db1c 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -32,12 +32,9 @@ TensorType, copy_func, download_url, - is_flax_available, - is_jax_tensor, is_numpy_array, is_offline_mode, is_remote_url, - is_tf_available, is_torch_available, is_torch_device, is_torch_dtype, @@ -110,21 +107,7 @@ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = if not isinstance(tensor_type, TensorType): tensor_type = TensorType(tensor_type) - # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW: - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) - if not is_tf_available(): - raise ImportError( - "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." - ) - import tensorflow as tf - - as_tensor = tf.constant - is_tensor = tf.is_tensor - elif tensor_type == TensorType.PYTORCH: + if tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") import torch # noqa @@ -145,17 +128,6 @@ def as_tensor(value): return torch.tensor(value) is_tensor = torch.is_tensor - elif tensor_type == TensorType.JAX: - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) - if not is_flax_available(): - raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") - import jax.numpy as jnp # noqa: F811 - - as_tensor = jnp.array - is_tensor = is_jax_tensor else: def as_tensor(value, dtype=None): diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index fc6f722262d9..ae214c74f37d 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -106,9 +106,6 @@ is_spacy_available, is_speech_available, is_tensor, - is_tensorflow_probability_available, - is_tf2onnx_available, - is_tf_available, is_timm_available, is_tokenizers_available, is_torch_available, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 4fb3d32213f8..673fdae99718 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -123,71 +123,7 @@ "SynthIDTextWatermarkDetector", ] -try: - if not is_tf_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tf_logits_process"] = [ - "TFForcedBOSTokenLogitsProcessor", - "TFForcedEOSTokenLogitsProcessor", - "TFForceTokensLogitsProcessor", - "TFLogitsProcessor", - "TFLogitsProcessorList", - "TFLogitsWarper", - "TFMinLengthLogitsProcessor", - "TFNoBadWordsLogitsProcessor", - "TFNoRepeatNGramLogitsProcessor", - "TFRepetitionPenaltyLogitsProcessor", - "TFSuppressTokensAtBeginLogitsProcessor", - "TFSuppressTokensLogitsProcessor", - "TFTemperatureLogitsWarper", - "TFTopKLogitsWarper", - "TFTopPLogitsWarper", - ] - _import_structure["tf_utils"] = [ - "TFGenerationMixin", - "TFGreedySearchDecoderOnlyOutput", - "TFGreedySearchEncoderDecoderOutput", - "TFSampleEncoderDecoderOutput", - "TFSampleDecoderOnlyOutput", - "TFBeamSearchEncoderDecoderOutput", - "TFBeamSearchDecoderOnlyOutput", - "TFBeamSampleEncoderDecoderOutput", - "TFBeamSampleDecoderOnlyOutput", - "TFContrastiveSearchEncoderDecoderOutput", - "TFContrastiveSearchDecoderOnlyOutput", - ] -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["flax_logits_process"] = [ - "FlaxForcedBOSTokenLogitsProcessor", - "FlaxForcedEOSTokenLogitsProcessor", - "FlaxForceTokensLogitsProcessor", - "FlaxLogitsProcessor", - "FlaxLogitsProcessorList", - "FlaxLogitsWarper", - "FlaxMinLengthLogitsProcessor", - "FlaxSuppressTokensAtBeginLogitsProcessor", - "FlaxSuppressTokensLogitsProcessor", - "FlaxTemperatureLogitsWarper", - "FlaxTopKLogitsWarper", - "FlaxTopPLogitsWarper", - "FlaxWhisperTimeStampLogitsProcessor", - "FlaxNoRepeatNGramLogitsProcessor", - ] - _import_structure["flax_utils"] = [ - "FlaxGenerationMixin", - "FlaxGreedySearchOutput", - "FlaxSampleOutput", - "FlaxBeamSearchOutput", - ] if TYPE_CHECKING: from .configuration_utils import ( @@ -283,66 +219,6 @@ WatermarkDetectorOutput, ) - try: - if not is_tf_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .tf_logits_process import ( - TFForcedBOSTokenLogitsProcessor, - TFForcedEOSTokenLogitsProcessor, - TFForceTokensLogitsProcessor, - TFLogitsProcessor, - TFLogitsProcessorList, - TFLogitsWarper, - TFMinLengthLogitsProcessor, - TFNoBadWordsLogitsProcessor, - TFNoRepeatNGramLogitsProcessor, - TFRepetitionPenaltyLogitsProcessor, - TFSuppressTokensAtBeginLogitsProcessor, - TFSuppressTokensLogitsProcessor, - TFTemperatureLogitsWarper, - TFTopKLogitsWarper, - TFTopPLogitsWarper, - ) - from .tf_utils import ( - TFBeamSampleDecoderOnlyOutput, - TFBeamSampleEncoderDecoderOutput, - TFBeamSearchDecoderOnlyOutput, - TFBeamSearchEncoderDecoderOutput, - TFContrastiveSearchDecoderOnlyOutput, - TFContrastiveSearchEncoderDecoderOutput, - TFGenerationMixin, - TFGreedySearchDecoderOnlyOutput, - TFGreedySearchEncoderDecoderOutput, - TFSampleDecoderOnlyOutput, - TFSampleEncoderDecoderOutput, - ) - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .flax_logits_process import ( - FlaxForcedBOSTokenLogitsProcessor, - FlaxForcedEOSTokenLogitsProcessor, - FlaxForceTokensLogitsProcessor, - FlaxLogitsProcessor, - FlaxLogitsProcessorList, - FlaxLogitsWarper, - FlaxMinLengthLogitsProcessor, - FlaxNoRepeatNGramLogitsProcessor, - FlaxSuppressTokensAtBeginLogitsProcessor, - FlaxSuppressTokensLogitsProcessor, - FlaxTemperatureLogitsWarper, - FlaxTopKLogitsWarper, - FlaxTopPLogitsWarper, - FlaxWhisperTimeStampLogitsProcessor, - ) - from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput else: import sys diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index f0aeae8985b7..a83f6254e044 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -26,10 +26,8 @@ get_image_size, infer_channel_dimension_format, ) -from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor +from .utils import ExplicitEnum, TensorType, is_torch_tensor from .utils.import_utils import ( - is_flax_available, - is_tf_available, is_torch_available, is_vision_available, requires_backends, @@ -44,12 +42,6 @@ if is_torch_available(): import torch -if is_tf_available(): - import tensorflow as tf - -if is_flax_available(): - import jax.numpy as jnp - def to_channel_dimension_format( image: np.ndarray, @@ -160,7 +152,7 @@ def _rescale_for_pil_conversion(image): def to_pil_image( - image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"], + image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor"], do_rescale: Optional[bool] = None, image_mode: Optional[str] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -170,7 +162,7 @@ def to_pil_image( needed. Args: - image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`): + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`): The image to convert to the `PIL.Image` format. do_rescale (`bool`, *optional*): Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default @@ -190,10 +182,8 @@ def to_pil_image( return image # Convert all tensors to numpy arrays before converting to PIL image - if is_torch_tensor(image) or is_tf_tensor(image): + if is_torch_tensor(image): image = image.numpy() - elif is_jax_tensor(image): - image = np.array(image) elif not isinstance(image, np.ndarray): raise ValueError(f"Input image type not supported: {type(image)}") @@ -556,16 +546,6 @@ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: return bboxes_corners -def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor": - center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1) - bboxes_corners = tf.stack( - # top left x, top left y, bottom right x, bottom right y - [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], - axis=-1, - ) - return bboxes_corners - - # 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py def center_to_corners_format(bboxes_center: TensorType) -> TensorType: """ @@ -582,8 +562,6 @@ def center_to_corners_format(bboxes_center: TensorType) -> TensorType: return _center_to_corners_format_torch(bboxes_center) elif isinstance(bboxes_center, np.ndarray): return _center_to_corners_format_numpy(bboxes_center) - elif is_tf_tensor(bboxes_center): - return _center_to_corners_format_tf(bboxes_center) raise ValueError(f"Unsupported input type {type(bboxes_center)}") @@ -613,20 +591,6 @@ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: return bboxes_center -def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor": - top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1) - bboxes_center = tf.stack( - [ - (top_left_x + bottom_right_x) / 2, # center x - (top_left_y + bottom_right_y) / 2, # center y - (bottom_right_x - top_left_x), # width - (bottom_right_y - top_left_y), # height - ], - axis=-1, - ) - return bboxes_center - - def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: """ Converts bounding boxes from corners format to center format. @@ -641,8 +605,6 @@ def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: return _corners_to_center_format_torch(bboxes_corners) elif isinstance(bboxes_corners, np.ndarray): return _corners_to_center_format_numpy(bboxes_corners) - elif is_tf_tensor(bboxes_corners): - return _corners_to_center_format_tf(bboxes_corners) raise ValueError(f"Unsupported input type {type(bboxes_corners)}") diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 2079c21f3b0c..1d988f99379c 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -24,9 +24,7 @@ from .utils import ( ExplicitEnum, - is_jax_tensor, is_numpy_array, - is_tf_tensor, is_torch_available, is_torch_tensor, is_torchvision_available, @@ -107,8 +105,6 @@ class ImageType(ExplicitEnum): PIL = "pillow" TORCH = "torch" NUMPY = "numpy" - TENSORFLOW = "tensorflow" - JAX = "jax" def get_image_type(image): @@ -118,15 +114,11 @@ def get_image_type(image): return ImageType.TORCH if is_numpy_array(image): return ImageType.NUMPY - if is_tf_tensor(image): - return ImageType.TENSORFLOW - if is_jax_tensor(image): - return ImageType.JAX raise ValueError(f"Unrecognized image type {type(image)}") def is_valid_image(img): - return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img) + return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) def is_valid_list_of_images(images: list): @@ -205,8 +197,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]: ) return images raise ValueError( - "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or " - f"jax.ndarray, but got {type(images)}." + f"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, or torch.Tensor, but got {type(images)}." ) @@ -570,7 +561,6 @@ def validate_preprocess_arguments( raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.") -# In the future we can add a TF implementation here when we have TF models. class ImageFeatureExtractionMixin: """ Mixin that contain utilities for preparing image features. diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py deleted file mode 100644 index ab7fc4615b47..000000000000 --- a/src/transformers/keras_callbacks.py +++ /dev/null @@ -1,413 +0,0 @@ -import logging -import os -from pathlib import Path -from time import sleep -from typing import Callable, Optional, Union - -import numpy as np -import tensorflow as tf -from huggingface_hub import Repository, create_repo -from packaging.version import parse - -from . import IntervalStrategy, PreTrainedTokenizerBase -from .modelcard import TrainingSummary -from .modeling_tf_utils import keras - - -logger = logging.getLogger(__name__) - - -class KerasMetricCallback(keras.callbacks.Callback): - """ - Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be - compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string - operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the - `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute - metrics and return a dict mapping metric names to metric values. - - We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that - this example skips some post-processing for readability and simplicity, and should probably not be used as-is! - - ```py - from datasets import load_metric - - rouge_metric = load_metric("rouge") - - - def rouge_fn(predictions, labels): - decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels) - return {key: value.mid.fmeasure * 100 for key, value in result.items()} - ``` - - The above function will return a dict containing values which will be logged like any other Keras metric: - - ``` - {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781 - ``` - - Args: - metric_fn (`Callable`): - Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`. - These contain the model's outputs and matching labels from the dataset. It should return a dict mapping - metric names to numerical values. - eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`): - Validation data to be used to generate predictions for the `metric_fn`. - output_cols (`list[str], *optional*): - A list of columns to be retained from the model output as the predictions. Defaults to all. - label_cols ('`list[str]`, *optional*'): - A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not - supplied. - batch_size (`int`, *optional*): - Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`. - predict_with_generate (`bool`, *optional*, defaults to `False`): - Whether we should use `model.generate()` to get outputs for the model. - use_xla_generation (`bool`, *optional*, defaults to `False`): - If we're generating, whether to compile model generation with XLA. This can massively increase the speed of - generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA - generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of` - argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and - save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`. - generate_kwargs (`dict`, *optional*): - Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate` - is `False`. - - """ - - def __init__( - self, - metric_fn: Callable, - eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict], - output_cols: Optional[list[str]] = None, - label_cols: Optional[list[str]] = None, - batch_size: Optional[int] = None, - predict_with_generate: bool = False, - use_xla_generation: bool = False, - generate_kwargs: Optional[dict] = None, - ): - super().__init__() - self.metric_fn = metric_fn - self.batch_size = batch_size - if not isinstance(eval_dataset, tf.data.Dataset): - if batch_size is None: - raise ValueError( - "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset " - "the batch_size argument must be set." - ) - # Wrap a tf.data.Dataset around it - eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False) - self.eval_dataset = eval_dataset - self.predict_with_generate = predict_with_generate - self.output_cols = output_cols - - # This next block attempts to parse out which elements of the dataset should be appended to the labels list - # that is passed to the metric_fn - if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2: - input_spec, label_spec = eval_dataset.element_spec - else: - input_spec = eval_dataset.element_spec - label_spec = None - if label_cols is not None: - for label in label_cols: - if label not in input_spec: - raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!") - self.label_cols = label_cols - self.use_keras_label = False - elif label_spec is not None: - # If the dataset inputs are split into a 2-tuple of inputs and labels, - # assume the second element is the labels - self.label_cols = None - self.use_keras_label = True - elif "labels" in input_spec: - self.label_cols = ["labels"] - self.use_keras_label = False - logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.") - elif "start_positions" in input_spec and "end_positions" in input_spec: - self.label_cols = ["start_positions", "end_positions"] - self.use_keras_label = False - logging.warning( - "No label_cols specified for KerasMetricCallback, assuming you want the " - "start_positions and end_positions keys." - ) - else: - raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!") - if parse(tf.__version__) < parse("2.7"): - logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!") - - self.use_xla_generation = use_xla_generation - self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs - - self.generation_function = None - - @staticmethod - def _concatenate_batches(batches, padding_index=-100): - # If all batches are unidimensional or same length, do a simple concatenation - if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches): - return np.concatenate(batches, axis=0) - - # Welp, they're not the same length. Let's do some padding - max_len = max([batch.shape[1] for batch in batches]) - num_samples = sum([batch.shape[0] for batch in batches]) - output = np.full_like( - batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:]) - ) - # i keeps track of which part of the concatenated array we're writing the next batch to - i = 0 - for batch in batches: - output[i : i + len(batch), : batch.shape[1]] = batch - i += len(batch) - return output - - def _postprocess_predictions_or_labels(self, inputs): - if isinstance(inputs[0], dict): - outputs = {} - for key in inputs[0]: - outputs[key] = self._concatenate_batches([batch[key] for batch in inputs]) - # If it's a dict with only one key, just return the array - if len(outputs) == 1: - outputs = list(outputs.values())[0] - elif isinstance(inputs[0], (tuple, list)): - outputs = [] - for input_list in zip(*inputs): - outputs.append(self._concatenate_batches(input_list)) - if len(outputs) == 1: - outputs = outputs[0] # If it's a list with only one element, just return the array - elif isinstance(inputs[0], np.ndarray): - outputs = self._concatenate_batches(inputs) - elif isinstance(inputs[0], tf.Tensor): - outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs]) - else: - raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!") - return outputs - - def on_epoch_end(self, epoch, logs=None): - if hasattr(self.model, "config"): - ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - main_input_name = None - if self.predict_with_generate: - # This dense conditional recognizes the case where we have an encoder-decoder model, but - # avoids getting tangled up when we just have a model with a layer called 'encoder' - if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"): - main_input_name = self.model.encoder.main_input_name - else: - main_input_name = getattr(self.model, "main_input_name", "input_ids") - - if self.use_xla_generation and self.generation_function is None: - - def generation_function(inputs, attention_mask): - return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs) - - self.generation_function = tf.function(generation_function, jit_compile=True) - - prediction_list = [] - label_list = [] - - # The whole predict/generate loop is handled inside this method - for batch in self.eval_dataset: - if isinstance(batch, tuple): - batch, labels = batch - else: - labels = None - if self.predict_with_generate: - if isinstance(batch, dict): - generation_inputs = batch[main_input_name] - attention_mask = batch.get("attention_mask", None) - else: - generation_inputs = batch - attention_mask = None - if self.use_xla_generation: - predictions = self.generation_function(generation_inputs, attention_mask=attention_mask) - else: - predictions = self.model.generate( - generation_inputs, attention_mask=attention_mask, **self.generate_kwargs - ) - else: - predictions = self.model.predict_on_batch(batch) - if isinstance(predictions, dict): - # This converts any dict-subclass to a regular dict - # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class - predictions = dict(predictions) - if self.output_cols is not None: - predictions = {key: predictions[key] for key in self.output_cols} - else: - predictions = { - key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"] - } - prediction_list.append(predictions) - if not self.use_keras_label: - labels = {key: batch[key].numpy() for key in self.label_cols} - elif isinstance(labels, dict): - labels = {key: array.numpy() for key, array in labels.items()} - elif isinstance(labels, (list, tuple)): - labels = [array.numpy() for array in labels] - elif isinstance(labels, tf.Tensor): - labels = labels.numpy() - else: - raise TypeError(f"Confused by labels of type {type(labels)}") - label_list.append(labels) - - all_preds = self._postprocess_predictions_or_labels(prediction_list) - all_labels = self._postprocess_predictions_or_labels(label_list) - - metric_output = self.metric_fn((all_preds, all_labels)) - if not isinstance(metric_output, dict): - raise TypeError( - f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}" - ) - # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch - # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of - # new keys in there, which will then get read by the History callback and treated like any other metric value. - # I promise that I have it in writing from Chollet that this is okay. - logs.update(metric_output) - - -class PushToHubCallback(keras.callbacks.Callback): - """ - Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can - be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such - as with the `from_pretrained` method. - - ```py - from transformers.keras_callbacks import PushToHubCallback - - push_to_hub_callback = PushToHubCallback( - output_dir="./model_save", - tokenizer=tokenizer, - hub_model_id="gpt5-7xlarge", - ) - - model.fit(train_dataset, callbacks=[push_to_hub_callback]) - ``` - - Args: - output_dir (`str`): - The output directory where the model predictions and checkpoints will be written and synced with the - repository on the Hub. - save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`): - The checkpoint save strategy to adopt during training. Possible values are: - - - `"no"`: Save is done at the end of training. - - `"epoch"`: Save is done at the end of each epoch. - - `"steps"`: Save is done every `save_steps` - save_steps (`int`, *optional*): - The number of steps between saves when using the "steps" `save_strategy`. - tokenizer (`PreTrainedTokenizerBase`, *optional*): - The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights. - hub_model_id (`str`, *optional*): - The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in - which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, - for instance `"user_name/model"`, which allows you to push to an organization you are a member of with - `"organization_name/model"`. - - Will default to the name of `output_dir`. - hub_token (`str`, *optional*): - The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with - `hf auth login`. - checkpoint (`bool`, *optional*, defaults to `False`): - Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be - resumed. Only usable when `save_strategy` is `"epoch"`. - """ - - def __init__( - self, - output_dir: Union[str, Path], - save_strategy: Union[str, IntervalStrategy] = "epoch", - save_steps: Optional[int] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - hub_model_id: Optional[str] = None, - hub_token: Optional[str] = None, - checkpoint: bool = False, - **model_card_args, - ): - super().__init__() - if checkpoint and save_strategy != "epoch": - raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!") - if isinstance(save_strategy, str): - save_strategy = IntervalStrategy(save_strategy.lower()) - self.save_strategy = save_strategy - if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0): - raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!") - self.save_steps = save_steps - output_dir = Path(output_dir) - - # Create repo and retrieve repo_id - if hub_model_id is None: - hub_model_id = output_dir.absolute().name - self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id - - self.output_dir = output_dir - self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token) - - self.tokenizer = tokenizer - self.last_job = None - self.checkpoint = checkpoint - self.training_history = None - self.model_card_args = model_card_args - - def on_train_begin(self, logs=None): - # Although we can access model.history, we have no guarantees that the History callback will fire before this - # one, so we keep track of it here too - self.training_history = [] - - def on_train_batch_end(self, batch, logs=None): - if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0: - if self.last_job is not None and not self.last_job.is_done: - return # The last upload is still running, don't start another - self.model.save_pretrained(self.output_dir) - if self.tokenizer is not None: - self.tokenizer.save_pretrained(self.output_dir) - _, self.last_job = self.repo.push_to_hub( - commit_message=f"Training in progress steps {batch}", blocking=False - ) - - def on_epoch_end(self, epoch, logs=None): - logs = logs.copy() # Don't accidentally write things that Keras will read later - if "epoch" not in logs: - logs["epoch"] = epoch - self.training_history.append(logs) - if self.save_strategy == IntervalStrategy.EPOCH: - if self.last_job is not None and not self.last_job.is_done: - return # The last upload is still running, don't start another - self.model.save_pretrained(self.output_dir) - if self.tokenizer is not None: - self.tokenizer.save_pretrained(self.output_dir) - if self.checkpoint: - checkpoint_dir = os.path.join(self.output_dir, "checkpoint") - self.model._save_checkpoint(checkpoint_dir, epoch) - train_summary = TrainingSummary.from_keras( - model=self.model, - model_name=self.hub_model_id, - keras_history=self.training_history, - **self.model_card_args, - ) - model_card = train_summary.to_model_card() - with (self.output_dir / "README.md").open("w") as f: - f.write(model_card) - _, self.last_job = self.repo.push_to_hub( - commit_message=f"Training in progress epoch {epoch}", blocking=False - ) - - def on_train_end(self, logs=None): - # Makes sure the latest version of the model is uploaded - if self.last_job is not None and not self.last_job.is_done: - logging.info("Pushing the last epoch to the Hub, this may take a while...") - while not self.last_job.is_done: - sleep(1) - else: - self.model.save_pretrained(self.output_dir) - if self.tokenizer is not None: - self.tokenizer.save_pretrained(self.output_dir) - train_summary = TrainingSummary.from_keras( - model=self.model, - model_name=self.hub_model_id, - keras_history=self.training_history, - **self.model_card_args, - ) - model_card = train_summary.to_model_card() - with (self.output_dir / "README.md").open("w") as f: - f.write(model_card) - self.repo.push_to_hub(commit_message="End of training", blocking=True) diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 8c68d8b8af10..8ba390ee7cf5 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -51,7 +51,6 @@ cached_file, is_datasets_available, is_offline_mode, - is_tf_available, is_tokenizers_available, is_torch_available, logging, @@ -256,11 +255,6 @@ def to_json_file(self, json_file_path): should probably proofread and complete it, then remove this comment. --> """ -AUTOGENERATED_KERAS_COMMENT = """ - -""" - TASK_TAG_TO_NAME_MAPPING = { "fill-mask": "Masked Language Modeling", @@ -483,8 +477,6 @@ def to_model_card(self): # Now the model card for realsies. if self.source == "trainer": model_card += AUTOGENERATED_TRAINER_COMMENT - else: - model_card += AUTOGENERATED_KERAS_COMMENT model_card += f"\n# {self.model_name}\n\n" @@ -538,10 +530,6 @@ def to_model_card(self): import torch model_card += f"- Pytorch {torch.__version__}\n" - elif self.source == "keras" and is_tf_available(): - import tensorflow as tf - - model_card += f"- TensorFlow {tf.__version__}\n" if is_datasets_available(): import datasets @@ -631,116 +619,6 @@ def from_trainer( hyperparameters=hyperparameters, ) - @classmethod - def from_keras( - cls, - model, - model_name, - keras_history=None, - language=None, - license=None, - tags=None, - finetuned_from=None, - tasks=None, - dataset_tags=None, - dataset=None, - dataset_args=None, - ): - # Infer default from dataset - if dataset is not None: - if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None): - default_tag = dataset.builder_name - # Those are not real datasets from the Hub so we exclude them. - if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: - if dataset_tags is None: - dataset_tags = [default_tag] - if dataset_args is None: - dataset_args = [dataset.config_name] - - if dataset is None and dataset_tags is not None: - dataset = dataset_tags - - # Infer default finetuned_from - if ( - finetuned_from is None - and hasattr(model.config, "_name_or_path") - and not os.path.isdir(model.config._name_or_path) - ): - finetuned_from = model.config._name_or_path - - # Infer default task tag: - if tasks is None: - model_class_name = model.__class__.__name__ - for task, mapping in TASK_MAPPING.items(): - if model_class_name in _get_mapping_values(mapping): - tasks = task - - # Add `generated_from_keras_callback` to the tags - if tags is None: - tags = ["generated_from_keras_callback"] - elif isinstance(tags, str) and tags != "generated_from_keras_callback": - tags = [tags, "generated_from_keras_callback"] - elif "generated_from_keras_callback" not in tags: - tags.append("generated_from_keras_callback") - - if keras_history is not None: - _, eval_lines, eval_results = parse_keras_history(keras_history) - else: - eval_lines = [] - eval_results = {} - hyperparameters = extract_hyperparameters_from_keras(model) - - return cls( - language=language, - license=license, - tags=tags, - model_name=model_name, - finetuned_from=finetuned_from, - tasks=tasks, - dataset_tags=dataset_tags, - dataset=dataset, - dataset_args=dataset_args, - eval_results=eval_results, - eval_lines=eval_lines, - hyperparameters=hyperparameters, - source="keras", - ) - - -def parse_keras_history(logs): - """ - Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict` - passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`. - """ - if hasattr(logs, "history"): - # This looks like a `History` object - if not hasattr(logs, "epoch"): - # This history looks empty, return empty results - return None, [], {} - logs.history["epoch"] = logs.epoch - logs = logs.history - else: - # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object - logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]} - - lines = [] - for i in range(len(logs["epoch"])): - epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()} - values = {} - for k, v in epoch_dict.items(): - if k.startswith("val_"): - k = "validation_" + k[4:] - elif k != "epoch": - k = "train_" + k - splits = k.split("_") - name = " ".join([part.capitalize() for part in splits]) - values[name] = v - lines.append(values) - - eval_results = lines[-1] - - return logs, lines, eval_results - def parse_log_history(log_history): """ @@ -804,19 +682,6 @@ def parse_log_history(log_history): return train_log, lines, None -def extract_hyperparameters_from_keras(model): - from .modeling_tf_utils import keras - - hyperparameters = {} - if hasattr(model, "optimizer") and model.optimizer is not None: - hyperparameters["optimizer"] = model.optimizer.get_config() - else: - hyperparameters["optimizer"] = None - hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name - - return hyperparameters - - def _maybe_round(v, decimals=4): if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals: return f"{v:.{decimals}f}" diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 12c3e7cd99ef..bd5c7780133c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2018,19 +2018,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class for this model architecture. - - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, - taking as arguments: - - - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint. - - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model. - - **path** (`str`) -- A path to the TensorFlow checkpoint. - - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP models, `pixel_values` for vision models and `input_values` for speech models). - - **can_record_outputs** (dict):""" + - **can_record_outputs** (dict): + """ config_class = None base_model_prefix = "" diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 4cc129366bae..0b4f9c70d914 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -15,7 +15,6 @@ """PyTorch ALBERT model.""" import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -47,132 +46,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_albert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - print(name) - - for name, array in zip(names, arrays): - original_name = name - - # If saved from the TF HUB module - name = name.replace("module/", "") - - # Renaming and simplifying - name = name.replace("ffn_1", "ffn") - name = name.replace("bert/", "albert/") - name = name.replace("attention_1", "attention") - name = name.replace("transform/", "") - name = name.replace("LayerNorm_1", "full_layer_layer_norm") - name = name.replace("LayerNorm", "attention/LayerNorm") - name = name.replace("transformer/", "") - - # The feed forward layer had an 'intermediate' step which has been abstracted away - name = name.replace("intermediate/dense/", "") - name = name.replace("ffn/intermediate/output/dense/", "ffn_output/") - - # ALBERT attention was split between self and output which have been abstracted away - name = name.replace("/output/", "/") - name = name.replace("/self/", "/") - - # The pooler is a linear layer - name = name.replace("pooler/dense", "pooler") - - # The classifier was simplified to predictions from cls/predictions - name = name.replace("cls/predictions", "predictions") - name = name.replace("predictions/attention", "predictions") - - # Naming was changed to be more explicit - name = name.replace("embeddings/attention", "embeddings") - name = name.replace("inner_group_", "albert_layers/") - name = name.replace("group_", "albert_layer_groups/") - - # Classifier - if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name): - name = "classifier/" + name - - # No ALBERT model currently handles the next sentence prediction task - if "seq_relationship" in name: - name = name.replace("seq_relationship/output_", "sop_classifier/classifier/") - name = name.replace("weights", "weight") - - name = name.split("/") - - # Ignore the gradients applied by the LAMB/ADAM optimizers. - if ( - "adam_m" in name - or "adam_v" in name - or "AdamWeightDecayOptimizer" in name - or "AdamWeightDecayOptimizer_1" in name - or "global_step" in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - print(f"Initialize PyTorch weight {name} from {original_name}") - pointer.data = torch.from_numpy(array) - - return model - - class AlbertEmbeddings(nn.Module): """ Construct the embeddings from word, position and token_type embeddings. @@ -184,8 +57,7 @@ def __init__(self, config: AlbertConfig): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -546,15 +418,12 @@ def forward( @auto_docstring class AlbertPreTrainedModel(PreTrainedModel): config: AlbertConfig - load_tf_weights = load_tf_weights_in_albert base_model_prefix = "albert" _supports_sdpa = True def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1337,7 +1206,6 @@ def forward( __all__ = [ - "load_tf_weights_in_albert", "AlbertPreTrainedModel", "AlbertModel", "AlbertForPreTraining", diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 61468141c570..9abac9d7e9b2 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -101,8 +101,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index b9238d8bb071..3502e0094a3e 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -16,7 +16,6 @@ """PyTorch BERT model.""" import math -import os import warnings from dataclasses import dataclass from typing import Optional, Union @@ -51,79 +50,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -133,8 +59,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -778,7 +703,6 @@ def forward(self, sequence_output, pooled_output): @auto_docstring class BertPreTrainedModel(PreTrainedModel): config: BertConfig - load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" supports_gradient_checkpointing = True _supports_sdpa = True @@ -786,8 +710,6 @@ class BertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1797,5 +1719,4 @@ def forward( "BertLMHeadModel", "BertModel", "BertPreTrainedModel", - "load_tf_weights_in_bert", ] diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 4be87a0cd544..c42093237134 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -437,90 +437,6 @@ def forward( ) -def load_tf_weights_in_bert_generation( - model, tf_hub_path, model_class, is_encoder_named_decoder=False, is_encoder=False -): - try: - import numpy as np - import tensorflow.compat.v1 as tf - import tensorflow_hub as hub - import tensorflow_text # noqa: F401 - - tf.disable_eager_execution() - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_model = hub.Module(tf_hub_path) - init = tf.global_variables_initializer() - with tf.Session() as sess: - init.run() - all_variables = tf_model.variable_map - keep_track_variables = all_variables.copy() - for key in list(all_variables.keys()): - if "global" in key: - logger.info(f"Skipping {key}...") - continue - if not is_encoder: - model_pointer = getattr(model, model_class) - else: - model_pointer = model - is_embedding = False - logger.info(f"Trying to match {key}...") - # remove start_string = "module/bert/" - sub_layers = key.split("/")[2:] - if is_encoder_named_decoder and sub_layers[0] == "encoder": - logger.info(f"Skipping encoder layer {key} for decoder") - continue - if is_encoder and sub_layers[0] == "decoder": - logger.info(f"Skipping decoder layer {key} for encoder") - continue - for i, sub_layer in enumerate(sub_layers): - if sub_layer == "embeddings": - is_embedding = True - elif sub_layer == "LayerNorm": - is_embedding = False - if "layer" in sub_layer: - model_pointer = model_pointer.layer[int(sub_layer.split("_")[-1])] - elif sub_layer in ["kernel", "gamma"]: - model_pointer = model_pointer.weight - elif sub_layer == "beta": - model_pointer = model_pointer.bias - elif sub_layer == "encdec": - model_pointer = model_pointer.crossattention.self - elif sub_layer == "encdec_output": - model_pointer = model_pointer.crossattention.output - elif is_encoder_named_decoder and sub_layer == "decoder": - model_pointer = model_pointer.encoder - else: - if sub_layer == "attention" and "encdec" in sub_layers[i + 1]: - continue - try: - model_pointer = getattr(model_pointer, sub_layer) - except AttributeError: - logger.info(f"Skipping to initialize {key} at {sub_layer}...") - raise AttributeError - - array = np.asarray(sess.run(all_variables[key])) - if not is_embedding: - logger.info(f"Transposing numpy weight of shape {array.shape} for {key}") - array = np.transpose(array) - else: - model_pointer = model_pointer.weight - - if model_pointer.shape != array.shape: - raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched") - logger.info(f"Initialize PyTorch weight {key}") - - model_pointer.data = torch.from_numpy(array.astype(np.float32)) - keep_track_variables.pop(key, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(keep_track_variables.keys())}") - return model - - class BertGenerationEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" @@ -528,8 +444,7 @@ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -568,8 +483,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -876,5 +789,4 @@ def forward( "BertGenerationDecoder", "BertGenerationEncoder", "BertGenerationPreTrainedModel", - "load_tf_weights_in_bert_generation", ] diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f42b1eeaeeb1..2f776a6205f5 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -15,7 +15,6 @@ """PyTorch BigBird model.""" import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -66,165 +65,6 @@ } -def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False): - """Load tf checkpoints in a pytorch model.""" - - def load_tf_weights_bert(init_vars, tf_path): - names = [] - tf_weights = {} - - for name, shape in init_vars: - array = tf.train.load_variable(tf_path, name) - name = name.replace("bert/encoder/LayerNorm", "bert/embeddings/LayerNorm") - logger.info(f"Loading TF weight {name} with shape {shape}") - names.append(name) - tf_weights[name] = array - - return names, tf_weights - - def load_tf_weights_trivia_qa(init_vars): - names = [] - tf_weights = {} - - for i, var in enumerate(init_vars): - name_items = var.name.split("/") - - if "transformer_scaffold" in name_items[0]: - layer_name_items = name_items[0].split("_") - if len(layer_name_items) < 3: - layer_name_items += [0] - - name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}" - - name = "/".join([_TRIVIA_QA_MAPPING.get(x, x) for x in name_items])[:-2] # remove last :0 in variable - - if "self/attention/output" in name: - name = name.replace("self/attention/output", "output") - - if i >= len(init_vars) - 2: - name = name.replace("intermediate", "output") - - logger.info(f"Loading TF weight {name} with shape {var.shape}") - array = var.value().numpy() - names.append(name) - tf_weights[name] = array - - return names, tf_weights - - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - - # Load weights from TF model - init_vars = tf.saved_model.load(tf_path).variables if is_trivia_qa else tf.train.list_variables(tf_path) - - if len(init_vars) <= 0: - raise ValueError("Loaded trained variables cannot be empty.") - - pt_names = list(model.state_dict().keys()) - - if is_trivia_qa: - names, tf_weights = load_tf_weights_trivia_qa(init_vars) - else: - names, tf_weights = load_tf_weights_bert(init_vars, tf_path) - - for txt_name in names: - array = tf_weights[txt_name] - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - pt_name = [] - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - pt_name.append("weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - pt_name.append("bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - pt_name.append("weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - pt_name.append("classifier") - elif scope_names[0] == "transform": - pointer = getattr(pointer, "transform") - pt_name.append("transform") - if ("bias" in name) or ("kernel" in name): - pointer = getattr(pointer, "dense") - pt_name.append("dense") - elif ("beta" in name) or ("gamma" in name): - pointer = getattr(pointer, "LayerNorm") - pt_name.append("LayerNorm") - else: - try: - pointer = getattr(pointer, scope_names[0]) - pt_name.append(f"{scope_names[0]}") - except AttributeError: - logger.info(f"Skipping {m_name}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - pt_name.append(f"{num}") - if m_name[-11:] == "_embeddings" or m_name == "embeddings": - pointer = getattr(pointer, "weight") - pt_name.append("weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if len(array.shape) > len(pointer.shape) and math.prod(array.shape) == math.prod(pointer.shape): - # print(txt_name, array.shape) - if ( - txt_name.endswith("attention/self/key/kernel") - or txt_name.endswith("attention/self/query/kernel") - or txt_name.endswith("attention/self/value/kernel") - ): - array = array.transpose(1, 0, 2).reshape(pointer.shape) - elif txt_name.endswith("attention/output/dense/kernel"): - array = array.transpose(0, 2, 1).reshape(pointer.shape) - else: - array = array.reshape(pointer.shape) - - if pointer.shape != array.shape: - raise ValueError( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}." - ) - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - pt_weight_name = ".".join(pt_name) - logger.info(f"Initialize PyTorch weight {pt_weight_name} from {txt_name}.") - pointer.data = torch.from_numpy(array) - tf_weights.pop(txt_name, None) - pt_names.remove(pt_weight_name) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - logger.info(f"Weights not initialized in PyTorch model: {', '.join(pt_names)}.") - return model - - class BigBirdEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -235,8 +75,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -937,8 +776,6 @@ def bigbird_block_sparse_attention( @staticmethod def torch_gather_b2(params, indices): - # this operation is equivalent to tf.gather when batch_dims=2 - if params.shape[:2] != indices.shape[:2]: raise ValueError( "Make sure that the first two dimensions of params and indices are identical, but" @@ -1708,15 +1545,12 @@ def forward(self, sequence_output, pooled_output): @auto_docstring class BigBirdPreTrainedModel(PreTrainedModel): config: BigBirdConfig - load_tf_weights = load_tf_weights_in_big_bird base_model_prefix = "bert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -2954,5 +2788,4 @@ def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int): "BigBirdLayer", "BigBirdModel", "BigBirdPreTrainedModel", - "load_tf_weights_in_big_bird", ] diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 90f3c886ad93..70644c8d3df2 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -753,8 +753,6 @@ def bigbird_block_sparse_attention( @staticmethod def torch_gather_b2(params, indices): - # this operation is equivalent to tf.gather when batch_dims=2 - if params.shape[:2] != indices.shape[:2]: raise ValueError( "Make sure that the first two dimensions of params and indices are identical, but" diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 59c5be00c316..3f91eb91ae4d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -838,8 +838,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 3a07402f739a..eb5439d93a3d 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -60,8 +60,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -672,8 +671,6 @@ class CamembertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 585961180f9e..9acbb476c2f8 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -16,7 +16,6 @@ import copy import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -84,103 +83,6 @@ class CanineModelOutputWithPooling(ModelOutput): attentions: Optional[tuple[torch.FloatTensor]] = None -def load_tf_weights_in_canine(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - # also discard the cls weights (which were used for the next sentence prediction pre-training task) - if any( - n - in [ - "adam_v", - "adam_m", - "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", - "global_step", - "cls", - "autoregressive_decoder", - "char_output_weights", - ] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - # if first scope name starts with "bert", change it to "encoder" - if name[0] == "bert": - name[0] = "encoder" - # remove "embeddings" middle name of HashBucketCodepointEmbedders - elif name[1] == "embeddings": - name.remove(name[1]) - # rename segment_embeddings to token_type_embeddings - elif name[1] == "segment_embeddings": - name[1] = "token_type_embeddings" - # rename initial convolutional projection layer - elif name[1] == "initial_char_encoder": - name = ["chars_to_molecules"] + name[-2:] - # rename final convolutional projection layer - elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]: - name = ["projection"] + name[1:] - pointer = model - for m_name in name: - if (re.fullmatch(r"[A-Za-z]+_\d+", m_name)) and "Embedder" not in m_name: - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]: - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class CanineEmbeddings(nn.Module): """Construct the character, position and token_type embeddings.""" @@ -197,8 +99,7 @@ def __init__(self, config): self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -873,15 +774,12 @@ def forward( @auto_docstring class CaninePreTrainedModel(PreTrainedModel): config: CanineConfig - load_tf_weights = load_tf_weights_in_canine base_model_prefix = "canine" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1545,5 +1443,4 @@ def forward( "CanineLayer", "CanineModel", "CaninePreTrainedModel", - "load_tf_weights_in_canine", ] diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 9d81a26581dd..6c8633788b63 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1008,8 +1008,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/codegen/tokenization_codegen.py b/src/transformers/models/codegen/tokenization_codegen.py index 152b1a84fc37..d8a5a2745ae7 100644 --- a/src/transformers/models/codegen/tokenization_codegen.py +++ b/src/transformers/models/codegen/tokenization_codegen.py @@ -22,14 +22,11 @@ import numpy as np import regex as re -from ...utils import is_tf_available, is_torch_available, logging, to_py_obj +from ...utils import logging, to_py_obj if TYPE_CHECKING: - if is_torch_available(): - import torch - if is_tf_available(): - import tensorflow as tf + import torch from ...tokenization_utils import AddedToken, PreTrainedTokenizer @@ -313,7 +310,7 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): def decode( self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, truncate_before_pattern: Optional[list[str]] = None, @@ -326,7 +323,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. diff --git a/src/transformers/models/codegen/tokenization_codegen_fast.py b/src/transformers/models/codegen/tokenization_codegen_fast.py index 7bac0db7de4e..4cbeff06ad89 100644 --- a/src/transformers/models/codegen/tokenization_codegen_fast.py +++ b/src/transformers/models/codegen/tokenization_codegen_fast.py @@ -19,14 +19,12 @@ import numpy as np -from ...utils import is_tf_available, is_torch_available, logging +from ...utils import is_torch_available, logging if TYPE_CHECKING: if is_torch_available(): import torch - if is_tf_available(): - import tensorflow as tf from ...tokenization_utils_base import BatchEncoding @@ -160,7 +158,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = def decode( self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, truncate_before_pattern: Optional[list[str]] = None, @@ -173,7 +171,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index 9221c463fe85..d327d3b8489e 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -55,8 +55,6 @@ ) from ...utils import ( TensorType, - is_flax_available, - is_jax_tensor, is_scipy_available, is_tf_available, is_tf_tensor, @@ -203,18 +201,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -1205,10 +1195,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1393,10 +1381,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 080b93fa92a6..a5bc7912540e 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -15,8 +15,6 @@ """PyTorch ConvBERT model.""" import math -import os -from operator import attrgetter from typing import Callable, Optional, Union import torch @@ -45,130 +43,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_convbert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - tf_data = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - tf_data[name] = array - - param_mapping = { - "embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings", - "embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings", - "embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings", - "embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma", - "embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta", - "embeddings_project.weight": "electra/embeddings_project/kernel", - "embeddings_project.bias": "electra/embeddings_project/bias", - } - if config.num_groups > 1: - group_dense_name = "g_dense" - else: - group_dense_name = "dense" - - for j in range(config.num_hidden_layers): - param_mapping[f"encoder.layer.{j}.attention.self.query.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/query/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.query.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/query/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.self.key.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/key/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.key.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/key/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.self.value.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/value/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.value.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/value/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.depthwise.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_key/depthwise_kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.pointwise.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_key/pointwise_kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_key/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.weight"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_point/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.bias"] = ( - f"electra/encoder/layer_{j}/attention/self/conv_attn_point/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.output.dense.weight"] = ( - f"electra/encoder/layer_{j}/attention/output/dense/kernel" - ) - param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.weight"] = ( - f"electra/encoder/layer_{j}/attention/output/LayerNorm/gamma" - ) - param_mapping[f"encoder.layer.{j}.attention.output.dense.bias"] = ( - f"electra/encoder/layer_{j}/attention/output/dense/bias" - ) - param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.bias"] = ( - f"electra/encoder/layer_{j}/attention/output/LayerNorm/beta" - ) - param_mapping[f"encoder.layer.{j}.intermediate.dense.weight"] = ( - f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/kernel" - ) - param_mapping[f"encoder.layer.{j}.intermediate.dense.bias"] = ( - f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/bias" - ) - param_mapping[f"encoder.layer.{j}.output.dense.weight"] = ( - f"electra/encoder/layer_{j}/output/{group_dense_name}/kernel" - ) - param_mapping[f"encoder.layer.{j}.output.dense.bias"] = ( - f"electra/encoder/layer_{j}/output/{group_dense_name}/bias" - ) - param_mapping[f"encoder.layer.{j}.output.LayerNorm.weight"] = ( - f"electra/encoder/layer_{j}/output/LayerNorm/gamma" - ) - param_mapping[f"encoder.layer.{j}.output.LayerNorm.bias"] = f"electra/encoder/layer_{j}/output/LayerNorm/beta" - - for param in model.named_parameters(): - param_name = param[0] - retriever = attrgetter(param_name) - result = retriever(model) - tf_name = param_mapping[param_name] - value = torch.from_numpy(tf_data[tf_name]) - logger.info(f"TF: {tf_name}, PT: {param_name} ") - if tf_name.endswith("/kernel"): - if not tf_name.endswith("/intermediate/g_dense/kernel"): - if not tf_name.endswith("/output/g_dense/kernel"): - value = value.T - if tf_name.endswith("/depthwise_kernel"): - value = value.permute(1, 2, 0) # 2, 0, 1 - if tf_name.endswith("/pointwise_kernel"): - value = value.permute(2, 1, 0) # 2, 1, 0 - if tf_name.endswith("/conv_attn_key/bias"): - value = value.unsqueeze(-1) - result.data = value - return model - - class ConvBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -178,8 +52,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -232,15 +105,12 @@ def forward( @auto_docstring class ConvBertPreTrainedModel(PreTrainedModel): config: ConvBertConfig - load_tf_weights = load_tf_weights_in_convbert base_model_prefix = "convbert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1330,5 +1200,4 @@ def forward( "ConvBertLayer", "ConvBertModel", "ConvBertPreTrainedModel", - "load_tf_weights_in_convbert", ] diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index f866dd9144a6..7f6843e98ca9 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -49,7 +49,7 @@ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText -class Data2VecTextForTextEmbeddings(nn.Module): +class Data2VecTextEmbeddings(nn.Module): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. """ @@ -61,8 +61,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index f9c68fcbdeae..ebf6c28c8765 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -40,63 +40,6 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2 -def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): - """Load tf checkpoints in a pytorch model""" - try: - import re - - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(gpt2_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array.squeeze()) - - for name, array in zip(names, arrays): - name = name[6:] # skip "model/" - name = name.split("/") - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "w" or scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "wpe" or scope_names[0] == "wte": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - # Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): attn_weights = torch.matmul(query, key.transpose(-1, -2)) @@ -456,7 +399,6 @@ def forward( @auto_docstring class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): config: DecisionTransformerConfig - load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index c6875eb9b8f8..27998d605502 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -55,11 +55,7 @@ ) from ...utils import ( TensorType, - is_flax_available, - is_jax_tensor, is_scipy_available, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_vision_available, @@ -201,18 +197,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -1203,10 +1191,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1391,10 +1377,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. diff --git a/src/transformers/models/deprecated/deta/image_processing_deta.py b/src/transformers/models/deprecated/deta/image_processing_deta.py index 434d25a1ab51..15220603bb40 100644 --- a/src/transformers/models/deprecated/deta/image_processing_deta.py +++ b/src/transformers/models/deprecated/deta/image_processing_deta.py @@ -50,10 +50,6 @@ validate_preprocess_arguments, ) from ....utils import ( - is_flax_available, - is_jax_tensor, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_torchvision_available, @@ -190,18 +186,11 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -829,10 +818,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1024,10 +1011,7 @@ def preprocess( annotations = [annotations] if annotations is not None else None if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") if annotations is not None and len(images) != len(annotations): raise ValueError( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." diff --git a/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py b/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py index ec2162db2cce..473d23d49565 100644 --- a/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py @@ -26,8 +26,8 @@ from ....tokenization_utils import AddedToken, PreTrainedTokenizer from ....tokenization_utils_base import BatchEncoding -from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging -from ....utils.generic import _is_jax, _is_numpy +from ....utils import TensorType, is_torch_available, logging +from ....utils.generic import _is_numpy logger = logging.get_logger(__name__) @@ -279,30 +279,13 @@ def convert_to_tensors( if not isinstance(tensor_type, TensorType): tensor_type = TensorType(tensor_type) - # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW: - if not is_tf_available(): - raise ImportError( - "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." - ) - import tensorflow as tf - - as_tensor = tf.constant - is_tensor = tf.is_tensor - elif tensor_type == TensorType.PYTORCH: + if tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") import torch as_tensor = torch.tensor is_tensor = torch.is_tensor - elif tensor_type == TensorType.JAX: - if not is_flax_available(): - raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") - import jax.numpy as jnp # noqa: F811 - - as_tensor = jnp.array - is_tensor = _is_jax else: as_tensor = np.asarray is_tensor = _is_numpy diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index ddfecac9f506..5fcc0318a50a 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -15,7 +15,6 @@ """PyTorch Nezha model.""" import math -import os import warnings from dataclasses import dataclass from typing import Optional, Union @@ -57,79 +56,6 @@ _CONFIG_FOR_DOC = "NezhaConfig" -def load_tf_weights_in_nezha(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class NezhaRelativePositionsEncoding(nn.Module): """Implement the Functional Relative Position Encoding""" @@ -703,15 +629,12 @@ class NezhaPreTrainedModel(PreTrainedModel): """ config: NezhaConfig - load_tf_weights = load_tf_weights_in_nezha base_model_prefix = "nezha" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index f92bc07a8bfb..57ffcae80e56 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -16,7 +16,6 @@ """PyTorch QDQBERT model.""" import math -import os import warnings from typing import Optional, Union @@ -71,79 +70,6 @@ _CONFIG_FOR_DOC = "QDQBertConfig" -def load_tf_weights_in_qdqbert(model, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class QDQBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -712,15 +638,12 @@ class QDQBertPreTrainedModel(PreTrainedModel): """ config: QDQBertConfig - load_tf_weights = load_tf_weights_in_qdqbert base_model_prefix = "bert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1732,5 +1655,4 @@ def forward( "QDQBertLMHeadModel", "QDQBertModel", "QDQBertPreTrainedModel", - "load_tf_weights_in_qdqbert", ] diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 9e2de5c9c1c4..bc177fcde7be 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -15,7 +15,6 @@ """PyTorch REALM model.""" import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -46,113 +45,6 @@ _CONFIG_FOR_DOC = "RealmConfig" -def load_tf_weights_in_realm(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - if isinstance(model, RealmReader) and "reader" not in name: - logger.info(f"Skipping {name} as it is not {model.__class__.__name__}'s parameter") - continue - - # For pretrained openqa reader - if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmForOpenQA): - name = name.replace("bert/", "reader/realm/") - name = name.replace("cls/", "reader/cls/") - - # For pretrained encoder - if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmKnowledgeAugEncoder): - name = name.replace("bert/", "realm/") - - # For finetuned reader - if name.startswith("reader"): - reader_prefix = "" if isinstance(model, RealmReader) else "reader/" - name = name.replace("reader/module/bert/", f"{reader_prefix}realm/") - name = name.replace("reader/module/cls/", f"{reader_prefix}cls/") - name = name.replace("reader/dense/", f"{reader_prefix}qa_outputs/dense_intermediate/") - name = name.replace("reader/dense_1/", f"{reader_prefix}qa_outputs/dense_output/") - name = name.replace("reader/layer_normalization", f"{reader_prefix}qa_outputs/layer_normalization") - - # For embedder and scorer - if name.startswith("module/module/module/"): # finetuned - embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/" - name = name.replace("module/module/module/module/bert/", f"{embedder_prefix}realm/") - name = name.replace("module/module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/") - name = name.replace("module/module/module/dense/", f"{embedder_prefix}cls/dense/") - name = name.replace("module/module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/") - name = name.replace("module/module/module/bert/", f"{embedder_prefix}realm/") - name = name.replace("module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/") - elif name.startswith("module/module/"): # pretrained - embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/" - name = name.replace("module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/") - name = name.replace("module/module/dense/", f"{embedder_prefix}cls/dense/") - - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert pointer.shape == array.shape, ( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class RealmEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -945,14 +837,11 @@ class RealmPreTrainedModel(PreTrainedModel): """ config: RealmConfig - load_tf_weights = load_tf_weights_in_realm base_model_prefix = "realm" def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1851,5 +1740,4 @@ def forward( "RealmPreTrainedModel", "RealmReader", "RealmScorer", - "load_tf_weights_in_realm", ] diff --git a/src/transformers/models/deprecated/realm/retrieval_realm.py b/src/transformers/models/deprecated/realm/retrieval_realm.py index b5e47abb1179..354ca2aba63a 100644 --- a/src/transformers/models/deprecated/realm/retrieval_realm.py +++ b/src/transformers/models/deprecated/realm/retrieval_realm.py @@ -31,16 +31,6 @@ logger = logging.get_logger(__name__) -def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray: - import tensorflow.compat.v1 as tf - - blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024) - blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True) - np_record = next(blocks_dataset.take(1).as_numpy_iterator()) - - return np_record - - class ScaNNSearcher: """Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included.""" diff --git a/src/transformers/models/deprecated/retribert/modeling_retribert.py b/src/transformers/models/deprecated/retribert/modeling_retribert.py index 06806e8e6d0b..926d7551e51b 100644 --- a/src/transformers/models/deprecated/retribert/modeling_retribert.py +++ b/src/transformers/models/deprecated/retribert/modeling_retribert.py @@ -40,7 +40,6 @@ class RetriBertPreTrainedModel(PreTrainedModel): """ config: RetriBertConfig - load_tf_weights = None base_model_prefix = "retribert" def _init_weights(self, module): diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index b6ae410c1474..1b4126f9ef20 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -15,7 +15,6 @@ """PyTorch TrajectoryTransformer model.""" import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -43,79 +42,6 @@ _CONFIG_FOR_DOC = "TrajectoryTransformerConfig" -def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - @dataclass class TrajectoryTransformerOutput(ModelOutput): """ @@ -154,7 +80,6 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): """ config: TrajectoryTransformerConfig - load_tf_weights = load_tf_weights_in_trajectory_transformer base_model_prefix = "trajectory_transformer" main_input_name = "trajectories" supports_gradient_checkpointing = True @@ -598,5 +523,4 @@ def forward( __all__ = [ "TrajectoryTransformerModel", "TrajectoryTransformerPreTrainedModel", - "load_tf_weights_in_trajectory_transformer", ] diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py index 19c3fb0bd485..a7b4825e5fcd 100644 --- a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py @@ -44,133 +44,6 @@ _CONFIG_FOR_DOC = "TransfoXLConfig" -def build_tf_to_pytorch_map(model, config): - """ - A map of modules from TF to PyTorch. This time I use a map to keep the PyTorch model as identical to the original - PyTorch model as possible. - """ - tf_to_pt_map = {} - - if hasattr(model, "transformer"): - # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax - tf_to_pt_map.update( - { - "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight, - "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias, - } - ) - for i, (out_l, proj_l, tie_proj) in enumerate( - zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs) - ): - layer_str = f"transformer/adaptive_softmax/cutoff_{i}/" - if config.tie_word_embeddings: - tf_to_pt_map.update({layer_str + "b": out_l.bias}) - else: - raise NotImplementedError - # I don't think this is implemented in the TF code - tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias}) - if not tie_proj: - tf_to_pt_map.update({layer_str + "proj": proj_l}) - # Now load the rest of the transformer - model = model.transformer - - # Embeddings - for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): - layer_str = f"transformer/adaptive_embed/cutoff_{i}/" - tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l}) - - # Transformer blocks - for i, b in enumerate(model.layers): - layer_str = f"transformer/layer_{i}/" - tf_to_pt_map.update( - { - layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight, - layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias, - layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight, - layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight, - layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight, - layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight, - layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias, - layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight, - layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias, - layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight, - layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, - } - ) - - # Relative positioning biases - if config.untie_r: - r_r_list = [] - r_w_list = [] - for b in model.layers: - r_r_list.append(b.dec_attn.r_r_bias) - r_w_list.append(b.dec_attn.r_w_bias) - else: - r_r_list = [model.r_r_bias] - r_w_list = [model.r_w_bias] - tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list}) - return tf_to_pt_map - - -def load_tf_weights_in_transfo_xl(model, config, tf_path): - """Load tf checkpoints in a pytorch model""" - try: - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - # Build TF to PyTorch weights loading map - tf_to_pt_map = build_tf_to_pytorch_map(model, config) - - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - tf_weights[name] = array - - for name, pointer in tf_to_pt_map.items(): - assert name in tf_weights - array = tf_weights[name] - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if "kernel" in name or "proj" in name: - array = np.transpose(array) - if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1: - # Here we will split the TF weights - assert len(pointer) == array.shape[0] - for i, p_i in enumerate(pointer): - arr_i = array[i, ...] - try: - assert p_i.shape == arr_i.shape - except AssertionError as e: - e.args += (p_i.shape, arr_i.shape) - raise - logger.info(f"Initialize PyTorch weight {name} for layer {i}") - p_i.data = torch.from_numpy(arr_i) - else: - try: - assert pointer.shape == array.shape, ( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - tf_weights.pop(name, None) - tf_weights.pop(name + "/Adam", None) - tf_weights.pop(name + "/Adam_1", None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") - return model - - class PositionalEmbedding(nn.Module): def __init__(self, demb): super().__init__() @@ -459,7 +332,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel): """ config: TransfoXLConfig - load_tf_weights = load_tf_weights_in_transfo_xl base_model_prefix = "transformer" def _init_weight(self, weight): @@ -1299,5 +1171,4 @@ def forward( "TransfoXLLMHeadModel", "TransfoXLModel", "TransfoXLPreTrainedModel", - "load_tf_weights_in_transfo_xl", ] diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index f29bd48a5934..80287942b5f9 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -54,11 +54,7 @@ ) from ...utils import ( TensorType, - is_flax_available, - is_jax_tensor, is_scipy_available, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_vision_available, @@ -198,18 +194,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -1179,10 +1167,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1366,10 +1352,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index bc116b231af1..8024e01a59bf 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -562,7 +562,6 @@ def forward( @auto_docstring class DistilBertPreTrainedModel(PreTrainedModel): config: DistilBertConfig - load_tf_weights = None base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn = True diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index f1ae00a02e07..03178c4e8564 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -247,7 +247,6 @@ class DPRPretrainedContextEncoder(DPRPreTrainedModel): """ config: DPRConfig - load_tf_weights = None base_model_prefix = "ctx_encoder" @@ -258,7 +257,6 @@ class DPRPretrainedQuestionEncoder(DPRPreTrainedModel): """ config: DPRConfig - load_tf_weights = None base_model_prefix = "question_encoder" @@ -269,7 +267,6 @@ class DPRPretrainedReader(DPRPreTrainedModel): """ config: DPRConfig - load_tf_weights = None base_model_prefix = "span_predictor" diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a10b0b658337..c944fb45a38d 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -15,7 +15,6 @@ """PyTorch ELECTRA model.""" import math -import os from dataclasses import dataclass from typing import Callable, Optional, Union @@ -47,88 +46,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - for name, array in zip(names, arrays): - original_name: str = name - - try: - if isinstance(model, ElectraForMaskedLM): - name = name.replace("electra/embeddings/", "generator/embeddings/") - - if discriminator_or_generator == "generator": - name = name.replace("electra/", "discriminator/") - name = name.replace("generator/", "electra/") - - name = name.replace("dense_1", "dense_prediction") - name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias") - - name = name.split("/") - # print(original_name, name) - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in ["global_step", "temperature"] for n in name): - logger.info(f"Skipping {original_name}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name.endswith("_embeddings"): - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - print(f"Initialize PyTorch weight {name}", original_name) - pointer.data = torch.from_numpy(array) - except AttributeError as e: - print(f"Skipping {original_name}", name, e) - continue - return model - - class ElectraEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -637,15 +554,12 @@ def forward(self, generator_hidden_states): @auto_docstring class ElectraPreTrainedModel(PreTrainedModel): config: ElectraConfig - load_tf_weights = load_tf_weights_in_electra base_model_prefix = "electra" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1582,5 +1496,4 @@ def forward( "ElectraForTokenClassification", "ElectraModel", "ElectraPreTrainedModel", - "load_tf_weights_in_electra", ] diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index ddcf460f01ee..21015e50bb2f 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -609,8 +609,6 @@ class EsmPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 1dadc6f5377b..91c6990b77b9 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -676,7 +676,6 @@ def forward( # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert class FlaubertPreTrainedModel(PreTrainedModel): config: FlaubertConfig - load_tf_weights = None base_model_prefix = "transformer" def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index 4370344cccfb..d782be0856c8 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Funnel Transformer model.""" -import os from dataclasses import dataclass from typing import Optional, Union @@ -43,96 +42,6 @@ INF = 1e6 -def load_tf_weights_in_funnel(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - _layer_map = { - "k": "k_head", - "q": "q_head", - "v": "v_head", - "o": "post_proj", - "layer_1": "linear_1", - "layer_2": "linear_2", - "rel_attn": "attention", - "ff": "ffn", - "kernel": "weight", - "gamma": "weight", - "beta": "bias", - "lookup_table": "weight", - "word_embedding": "word_embeddings", - "input": "embeddings", - } - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - if name[0] == "generator": - continue - pointer = model - skipped = False - for m_name in name[1:]: - if not isinstance(pointer, FunnelPositionwiseFFN) and re.fullmatch(r"layer_\d+", m_name): - layer_index = int(re.search(r"layer_(\d+)", m_name).groups()[0]) - if layer_index < config.num_hidden_layers: - block_idx = 0 - while layer_index >= config.block_sizes[block_idx]: - layer_index -= config.block_sizes[block_idx] - block_idx += 1 - pointer = pointer.blocks[block_idx][layer_index] - else: - layer_index -= config.num_hidden_layers - pointer = pointer.layers[layer_index] - elif m_name == "r" and isinstance(pointer, FunnelRelMultiheadAttention): - pointer = pointer.r_kernel - break - elif m_name in _layer_map: - pointer = getattr(pointer, _layer_map[m_name]) - else: - try: - pointer = getattr(pointer, m_name) - except AttributeError: - print(f"Skipping {'/'.join(name)}", array.shape) - skipped = True - break - if not skipped: - if len(pointer.shape) != len(array.shape): - array = array.reshape(pointer.shape) - if m_name == "kernel": - array = np.transpose(array) - pointer.data = torch.from_numpy(array) - - return model - - class FunnelEmbeddings(nn.Module): def __init__(self, config: FunnelConfig) -> None: super().__init__() @@ -761,7 +670,6 @@ def forward(self, discriminator_hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring class FunnelPreTrainedModel(PreTrainedModel): config: FunnelConfig - load_tf_weights = load_tf_weights_in_funnel base_model_prefix = "funnel" def _init_weights(self, module): @@ -1448,5 +1356,4 @@ def forward( "FunnelForTokenClassification", "FunnelModel", "FunnelPreTrainedModel", - "load_tf_weights_in_funnel", ] diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ae0786179464..1cbea3f50da8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -16,7 +16,6 @@ """PyTorch OpenAI GPT-2 model.""" import math -import os import warnings from dataclasses import dataclass from typing import Callable, Optional, Union @@ -54,62 +53,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): - """Load tf checkpoints in a pytorch model""" - try: - import re - - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(gpt2_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array.squeeze()) - - for name, array in zip(names, arrays): - name = name[6:] # skip "model/" - name = name.split("/") - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "w" or scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "wpe" or scope_names[0] == "wte": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): attn_weights = torch.matmul(query, key.transpose(-1, -2)) @@ -562,7 +505,6 @@ def forward( @auto_docstring class GPT2PreTrainedModel(PreTrainedModel): config: GPT2Config - load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True @@ -580,8 +522,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1634,5 +1574,4 @@ def forward( "GPT2LMHeadModel", "GPT2Model", "GPT2PreTrainedModel", - "load_tf_weights_in_gpt2", ] diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 69d74565745a..5d1d1beb0405 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch GPT Neo model.""" -import os from typing import Optional, Union import torch @@ -63,86 +62,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): - """Load tf checkpoints in a pytorch model""" - try: - import re - - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(gpt_neo_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - if "global_step" not in name and "adam" not in name: - array = tf.train.load_variable(tf_path, name) - array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy() - name = name.replace("attn/q", "attn/attention/q_proj/w") - name = name.replace("attn/k", "attn/attention/k_proj/w") - name = name.replace("attn/v", "attn/attention/v_proj/w") - name = name.replace("attn/o", "attn/attention/out_proj/w") - name = name.replace("norm_1", "ln_1") - name = name.replace("norm_2", "ln_2") - name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b") - name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w") - name = name.replace("conv1d_main/c_fc/bias", "c_fc/b") - name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w") - name = name.replace("conv1d_main/c_proj/bias", "c_proj/b") - - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name[5:] # skip "gpt2/" - name = name.split("/") - pointer = model.transformer - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "w" or scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "wpe" or scope_names[0] == "wte": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - - if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: - array = array.transpose() - - if name == ["wte"]: - # if vocab is padded, then trim off the padding embeddings - array = array[: config.vocab_size] - - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}") - - print(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - - # init the final linear layer using word embeddings - embs = model.transformer.wte.weight - lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False) - lin.weight = embs - model.set_output_embeddings(lin) - return model - - class GPTNeoSelfAttention(nn.Module): def __init__(self, config, attention_type, layer_id=None): super().__init__() @@ -470,7 +389,6 @@ def forward( @auto_docstring class GPTNeoPreTrainedModel(PreTrainedModel): config: GPTNeoConfig - load_tf_weights = load_tf_weights_in_gpt_neo base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoBlock"] @@ -484,8 +402,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -843,8 +759,6 @@ def forward( if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) - # Compute loss in fp32 to match with mesh-tf version - # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 lm_logits = lm_logits.to(torch.float32) # Flatten the tokens @@ -1188,5 +1102,4 @@ def forward( "GPTNeoForTokenClassification", "GPTNeoModel", "GPTNeoPreTrainedModel", - "load_tf_weights_in_gpt_neo", ] diff --git a/src/transformers/models/grounding_dino/image_processing_grounding_dino.py b/src/transformers/models/grounding_dino/image_processing_grounding_dino.py index 24e3c8b3f987..2910ea471059 100644 --- a/src/transformers/models/grounding_dino/image_processing_grounding_dino.py +++ b/src/transformers/models/grounding_dino/image_processing_grounding_dino.py @@ -54,11 +54,7 @@ from ...utils import ( ExplicitEnum, TensorType, - is_flax_available, - is_jax_tensor, is_scipy_available, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_vision_available, @@ -210,18 +206,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -1241,10 +1229,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1429,10 +1415,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 59c8078ad84a..5ab7e480c8ea 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -29,15 +29,13 @@ Unpack, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import is_tf_available, is_torch_available +from ...utils import is_torch_available from ...utils.deprecation import deprecate_kwarg if is_torch_available(): import torch -if is_tf_available(): - import tensorflow as tf IMAGE_TOKEN = "" @@ -74,8 +72,6 @@ def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_c if num_classes != -1: if return_tensors == "pt": incremental_mask[incremental_mask >= num_classes] = -1 - elif return_tensors == "tf": - incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask) # Create mask for negative values if return_tensors == "pt": @@ -83,13 +79,6 @@ def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_c incremental_mask[negatives] = 0 attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes) attn_mask[negatives, :] = 0 - elif return_tensors == "tf": - negatives = tf.equal(incremental_mask, -1) - incremental_mask = tf.where(negatives, 0, incremental_mask) - attn_mask = tf.one_hot(incremental_mask, depth=num_classes) - # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1] - negatives_expanded = tf.expand_dims(negatives, -1) - attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask) return attn_mask @@ -98,8 +87,6 @@ def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_c def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors): if return_tensors == "pt": return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer) - elif return_tensors == "tf": - return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer) def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer): @@ -149,39 +136,6 @@ def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer): return image_attention_mask, next_image_attention_mask -def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer): - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - eod_token_id = tokenizer.eos_token_id - batch_size = tf.shape(input_ids)[0] - image_attention_mask = tf.fill(tf.shape(input_ids), -1) - next_image_attention_mask = tf.fill(tf.shape(input_ids), -1) - - for batch_idx in range(batch_size): - count = -1 - seen_eod = False - seq_length = tf.shape(input_ids)[1] - - for idx in range(seq_length - 1, -1, -1): - token_id = input_ids[batch_idx, idx].numpy() - if token_id == image_token_id: - count += 1 - indices = [[batch_idx, idx]] - updates = [count] - image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates) - next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) - elif token_id == eod_token_id and not seen_eod: - seen_eod = True - count = 0 - indices = [[batch_idx, idx]] - updates = [count] - next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) - if seen_eod and token_id != eod_token_id: - indices = [[batch_idx, idx]] - updates = [-1] - next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates) - return image_attention_mask, next_image_attention_mask - - def is_url(string): """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately invalidated the url""" @@ -451,42 +405,19 @@ def image_tokens(last_was_image): if return_tensors == "pt": padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) padded_image_tensor[: current_images.size(0)] = current_images - elif return_tensors == "tf": - # Assuming current_images is a TensorFlow tensor - # Get the shape of current_images, excluding the first dimension - image_shape = tf.shape(current_images)[1:] - # Create a shape for the padded_image_tensor - padded_shape = tf.concat([[max_num_images], image_shape], axis=0) - # Create the padded_image_tensor of zeros - padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype) - # Get the number of images (assuming current_images has shape [num_images, height, width, channels]) - num_images = tf.shape(current_images)[0] - # Update the padded_image_tensor with the values from current_images - indices = tf.reshape(tf.range(num_images), (-1, 1)) - updates = current_images - padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates) else: if return_tensors == "pt": padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) - elif return_tensors == "tf": - padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims)) output_images.append(padded_image_tensor) if return_tensors == "pt": output_input_ids.append(torch.tensor(padded_input_ids)) output_attention_masks.append(torch.tensor(attention_mask)) - elif return_tensors == "tf": - output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32)) - output_attention_masks.append(attention_mask) if return_tensors == "pt": output_input_ids = torch.stack(output_input_ids) output_images = torch.stack(output_images) output_attention_masks = torch.stack(output_attention_masks) - elif return_tensors == "tf": - output_input_ids = tf.stack(output_input_ids) - output_images = tf.stack(output_images) - output_attention_masks = tf.stack(output_attention_masks) if at_least_one_image: image_attention_mask, _ = image_attention_mask_for_packed_input_ids( @@ -501,10 +432,6 @@ def image_tokens(last_was_image): image_attention_mask = torch.zeros( output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool ) - elif return_tensors == "tf": - image_attention_mask = tf.zeros( - (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool - ) return BatchFeature( data={ "input_ids": output_input_ids, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index a962141e4479..cd527d28e5d1 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -15,7 +15,6 @@ """PyTorch OpenAI ImageGPT model.""" import math -import os from typing import Any, Optional, Union import torch @@ -44,114 +43,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path): - """ - Load tf checkpoints in a pytorch model - """ - try: - import re - - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(imagegpt_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array.squeeze()) - - for name, array in zip(names, arrays): - name = name[6:] # skip "model/" - name = name.split("/") - - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ) or name[-1] in ["_step"]: - logger.info("Skipping {}".format("/".join(name))) - continue - - pointer = model - if name[-1] not in ["wtet"]: - pointer = getattr(pointer, "transformer") - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - - if scope_names[0] == "w" or scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "wpe" or scope_names[0] == "wte": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]: - pointer = getattr(pointer, "c_attn") - pointer = getattr(pointer, "weight") - elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - elif scope_names[0] == "wtet": - pointer = getattr(pointer, "lm_head") - pointer = getattr(pointer, "weight") - elif scope_names[0] == "sos": - pointer = getattr(pointer, "wte") - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - - if len(name) > 1 and name[1] == "attn" or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte": - pass # array is used to initialize only part of the pointer so sizes won't match - else: - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - - logger.info(f"Initialize PyTorch weight {name}") - - if name[-1] == "q_proj": - pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T - elif name[-1] == "k_proj": - pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy( - array.reshape(config.n_embd, config.n_embd) - ).T - elif name[-1] == "v_proj": - pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T - elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj": - pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)) - elif name[-1] == "wtet": - pointer.data = torch.from_numpy(array) - elif name[-1] == "wte": - pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array) - elif name[-1] == "sos": - pointer.data[-1] = torch.from_numpy(array) - else: - pointer.data = torch.from_numpy(array) - - return model - - class ImageGPTLayerNorm(nn.Module): def __init__(self, hidden_size: tuple[int], eps: float = 1e-5): super().__init__() @@ -498,7 +389,6 @@ def forward( @auto_docstring class ImageGPTPreTrainedModel(PreTrainedModel): config: ImageGPTConfig - load_tf_weights = load_tf_weights_in_imagegpt base_model_prefix = "transformer" main_input_name = "input_ids" supports_gradient_checkpointing = True @@ -510,8 +400,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1020,5 +908,4 @@ def forward( "ImageGPTForImageClassification", "ImageGPTModel", "ImageGPTPreTrainedModel", - "load_tf_weights_in_imagegpt", ] diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 00243ce12329..552508e76974 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -15,7 +15,6 @@ """PyTorch LXMERT model.""" import math -import os import warnings from dataclasses import dataclass from typing import Optional, Union @@ -179,85 +178,6 @@ class LxmertForPreTrainingOutput(ModelOutput): cross_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None -def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n - in [ - "adam_v", - "adam_m", - "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", - "global_step", - ] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class LxmertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -760,15 +680,12 @@ def forward(self, sequence_output, pooled_output): @auto_docstring class LxmertPreTrainedModel(PreTrainedModel): config: LxmertConfig - load_tf_weights = load_tf_weights_in_lxmert base_model_prefix = "lxmert" _supports_param_buffer_assignment = False def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 78fbf8f215aa..61acac83b0f2 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -553,8 +553,6 @@ class MarkupLMPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index a75c0f575aca..d6a45bbd8eb8 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -16,7 +16,6 @@ """PyTorch MegatronBERT model.""" import math -import os import warnings from dataclasses import dataclass from typing import Optional, Union @@ -50,75 +49,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class MegatronBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -672,15 +602,12 @@ def forward(self, sequence_output, pooled_output): @auto_docstring class MegatronBertPreTrainedModel(PreTrainedModel): config: MegatronBertConfig - load_tf_weights = load_tf_weights_in_megatron_bert base_model_prefix = "bert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if hasattr(module, "bias") and module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 99768685002a..1d6c5f7c46f4 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -21,7 +21,6 @@ # SOFTWARE. import math -import os import warnings from dataclasses import dataclass from typing import Optional, Union @@ -50,84 +49,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.replace("ffn_layer", "ffn") - name = name.replace("FakeLayerNorm", "LayerNorm") - name = name.replace("extra_output_weights", "dense/kernel") - name = name.replace("bert", "mobilebert") - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert pointer.shape == array.shape, ( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class NoNorm(nn.Module): def __init__(self, feat_size, eps=None): super().__init__() @@ -659,14 +580,11 @@ def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> @auto_docstring class MobileBertPreTrainedModel(PreTrainedModel): config: MobileBertConfig - load_tf_weights = load_tf_weights_in_mobilebert base_model_prefix = "mobilebert" def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1478,5 +1396,4 @@ def forward( "MobileBertLayer", "MobileBertModel", "MobileBertPreTrainedModel", - "load_tf_weights_in_mobilebert", ] diff --git a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py index 25997a46790c..f80c6977bf18 100755 --- a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py @@ -29,109 +29,6 @@ logger = logging.get_logger(__name__) -def _build_tf_to_pytorch_map(model, config, tf_weights=None): - """ - A map of modules from TF to PyTorch. - """ - - tf_to_pt_map = {} - - if isinstance(model, MobileNetV1ForImageClassification): - backbone = model.mobilenet_v1 - else: - backbone = model - - prefix = "MobilenetV1/Conv2d_0/" - tf_to_pt_map[prefix + "weights"] = backbone.conv_stem.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = backbone.conv_stem.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = backbone.conv_stem.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.normalization.running_var - - for i in range(13): - tf_index = i + 1 - pt_index = i * 2 - - pointer = backbone.layer[pt_index] - prefix = f"MobilenetV1/Conv2d_{tf_index}_depthwise/" - tf_to_pt_map[prefix + "depthwise_weights"] = pointer.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var - - pointer = backbone.layer[pt_index + 1] - prefix = f"MobilenetV1/Conv2d_{tf_index}_pointwise/" - tf_to_pt_map[prefix + "weights"] = pointer.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var - - if isinstance(model, MobileNetV1ForImageClassification): - prefix = "MobilenetV1/Logits/Conv2d_1c_1x1/" - tf_to_pt_map[prefix + "weights"] = model.classifier.weight - tf_to_pt_map[prefix + "biases"] = model.classifier.bias - - return tf_to_pt_map - - -def load_tf_weights_in_mobilenet_v1(model, config, tf_checkpoint_path): - """Load TensorFlow checkpoints in a PyTorch model.""" - try: - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - # Load weights from TF model - init_vars = tf.train.list_variables(tf_checkpoint_path) - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_checkpoint_path, name) - tf_weights[name] = array - - # Build TF to PyTorch weights loading map - tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights) - - for name, pointer in tf_to_pt_map.items(): - logger.info(f"Importing {name}") - if name not in tf_weights: - logger.info(f"{name} not in tf pre-trained weights, skipping") - continue - - array = tf_weights[name] - - if "depthwise_weights" in name: - logger.info("Transposing depthwise") - array = np.transpose(array, (2, 3, 0, 1)) - elif "weights" in name: - logger.info("Transposing") - if len(pointer.shape) == 2: # copying into linear layer - array = array.squeeze().transpose() - else: - array = np.transpose(array, (3, 2, 0, 1)) - - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - - logger.info(f"Initialize PyTorch weight {name} {array.shape}") - pointer.data = torch.from_numpy(array) - - tf_weights.pop(name, None) - tf_weights.pop(name + "/RMSProp", None) - tf_weights.pop(name + "/RMSProp_1", None) - tf_weights.pop(name + "/ExponentialMovingAverage", None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") - return model - - def apply_tf_padding(features: torch.Tensor, conv_layer: nn.Conv2d) -> torch.Tensor: """ Apply TensorFlow-style "SAME" padding to a convolution layer. See the notes at: @@ -229,7 +126,6 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: @auto_docstring class MobileNetV1PreTrainedModel(PreTrainedModel): config: MobileNetV1Config - load_tf_weights = load_tf_weights_in_mobilenet_v1 base_model_prefix = "mobilenet_v1" main_input_name = "pixel_values" supports_gradient_checkpointing = False @@ -410,5 +306,4 @@ def forward( "MobileNetV1ForImageClassification", "MobileNetV1Model", "MobileNetV1PreTrainedModel", - "load_tf_weights_in_mobilenet_v1", ] diff --git a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py index 8f178f0480dd..d5c5f2e7fdb9 100755 --- a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py @@ -34,175 +34,6 @@ logger = logging.get_logger(__name__) -def _build_tf_to_pytorch_map(model, config, tf_weights=None): - """ - A map of modules from TF to PyTorch. - """ - - tf_to_pt_map = {} - - if isinstance(model, (MobileNetV2ForImageClassification, MobileNetV2ForSemanticSegmentation)): - backbone = model.mobilenet_v2 - else: - backbone = model - - # Use the EMA weights if available - def ema(x): - return x + "/ExponentialMovingAverage" if x + "/ExponentialMovingAverage" in tf_weights else x - - prefix = "MobilenetV2/Conv/" - tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_stem.first_conv.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.first_conv.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.first_conv.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.first_conv.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.first_conv.normalization.running_var - - prefix = "MobilenetV2/expanded_conv/depthwise/" - tf_to_pt_map[ema(prefix + "depthwise_weights")] = backbone.conv_stem.conv_3x3.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.conv_3x3.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.conv_3x3.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.conv_3x3.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.conv_3x3.normalization.running_var - - prefix = "MobilenetV2/expanded_conv/project/" - tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_stem.reduce_1x1.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.reduce_1x1.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.reduce_1x1.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.reduce_1x1.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.reduce_1x1.normalization.running_var - - for i in range(16): - tf_index = i + 1 - pt_index = i - pointer = backbone.layer[pt_index] - - prefix = f"MobilenetV2/expanded_conv_{tf_index}/expand/" - tf_to_pt_map[ema(prefix + "weights")] = pointer.expand_1x1.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.expand_1x1.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.expand_1x1.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.expand_1x1.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.expand_1x1.normalization.running_var - - prefix = f"MobilenetV2/expanded_conv_{tf_index}/depthwise/" - tf_to_pt_map[ema(prefix + "depthwise_weights")] = pointer.conv_3x3.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.conv_3x3.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.conv_3x3.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.conv_3x3.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.conv_3x3.normalization.running_var - - prefix = f"MobilenetV2/expanded_conv_{tf_index}/project/" - tf_to_pt_map[ema(prefix + "weights")] = pointer.reduce_1x1.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.reduce_1x1.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.reduce_1x1.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.reduce_1x1.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.reduce_1x1.normalization.running_var - - prefix = "MobilenetV2/Conv_1/" - tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_1x1.convolution.weight - tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_1x1.normalization.bias - tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_1x1.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_1x1.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_1x1.normalization.running_var - - if isinstance(model, MobileNetV2ForImageClassification): - prefix = "MobilenetV2/Logits/Conv2d_1c_1x1/" - tf_to_pt_map[ema(prefix + "weights")] = model.classifier.weight - tf_to_pt_map[ema(prefix + "biases")] = model.classifier.bias - - if isinstance(model, MobileNetV2ForSemanticSegmentation): - prefix = "image_pooling/" - tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_pool.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_pool.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_pool.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = model.segmentation_head.conv_pool.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( - model.segmentation_head.conv_pool.normalization.running_var - ) - - prefix = "aspp0/" - tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_aspp.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_aspp.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_aspp.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = model.segmentation_head.conv_aspp.normalization.running_mean - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( - model.segmentation_head.conv_aspp.normalization.running_var - ) - - prefix = "concat_projection/" - tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_projection.convolution.weight - tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_projection.normalization.bias - tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_projection.normalization.weight - tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = ( - model.segmentation_head.conv_projection.normalization.running_mean - ) - tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( - model.segmentation_head.conv_projection.normalization.running_var - ) - - prefix = "logits/semantic/" - tf_to_pt_map[ema(prefix + "weights")] = model.segmentation_head.classifier.convolution.weight - tf_to_pt_map[ema(prefix + "biases")] = model.segmentation_head.classifier.convolution.bias - - return tf_to_pt_map - - -def load_tf_weights_in_mobilenet_v2(model, config, tf_checkpoint_path): - """Load TensorFlow checkpoints in a PyTorch model.""" - try: - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - - # Load weights from TF model - init_vars = tf.train.list_variables(tf_checkpoint_path) - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_checkpoint_path, name) - tf_weights[name] = array - - # Build TF to PyTorch weights loading map - tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights) - - for name, pointer in tf_to_pt_map.items(): - logger.info(f"Importing {name}") - if name not in tf_weights: - logger.info(f"{name} not in tf pre-trained weights, skipping") - continue - - array = tf_weights[name] - - if "depthwise_weights" in name: - logger.info("Transposing depthwise") - array = np.transpose(array, (2, 3, 0, 1)) - elif "weights" in name: - logger.info("Transposing") - if len(pointer.shape) == 2: # copying into linear layer - array = array.squeeze().transpose() - else: - array = np.transpose(array, (3, 2, 0, 1)) - - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - - logger.info(f"Initialize PyTorch weight {name} {array.shape}") - pointer.data = torch.from_numpy(array) - - tf_weights.pop(name, None) - tf_weights.pop(name + "/RMSProp", None) - tf_weights.pop(name + "/RMSProp_1", None) - tf_weights.pop(name + "/ExponentialMovingAverage", None) - tf_weights.pop(name + "/Momentum", None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") - return model - - def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: """ Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the @@ -423,7 +254,6 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: @auto_docstring class MobileNetV2PreTrainedModel(PreTrainedModel): config: MobileNetV2Config - load_tf_weights = load_tf_weights_in_mobilenet_v2 base_model_prefix = "mobilenet_v2" main_input_name = "pixel_values" supports_gradient_checkpointing = False @@ -782,5 +612,4 @@ def forward( "MobileNetV2ForSemanticSegmentation", "MobileNetV2Model", "MobileNetV2PreTrainedModel", - "load_tf_weights_in_mobilenet_v2", ] diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 4e57d0aadda2..ca3851218ccf 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -16,7 +16,6 @@ import copy import math -import os import warnings from typing import Optional, Union @@ -630,112 +629,6 @@ def forward( ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) -def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - tf_weights[name] = array - - for txt_name in names: - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - pointer = model - array = tf_weights[txt_name] - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[0] - elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[1] - elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") - pointer = pointer[2] - elif scope_names[0] == "rms_norm": - if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") - elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") - elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": - continue - elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): - pointer = getattr(pointer, f"wi_{scope_names[1]}") - continue - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") - array = np.transpose(array) - try: - assert pointer.shape == array.shape, ( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array.astype(np.float32)) - tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - return model - - # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->MT5 class MT5ClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index 2e69379af73f..271bea054931 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -30,7 +30,7 @@ Unpack, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available +from ...utils import TensorType, is_torch_available if TYPE_CHECKING: @@ -105,10 +105,8 @@ def __call__( should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -157,24 +155,11 @@ def __call__( input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) - elif return_tensors == "jax" and is_flax_available(): - import jax.numpy as jnp - - input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) - attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) - elif return_tensors == "pt" and is_torch_available(): import torch input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) - - elif return_tensors == "tf" and is_tf_available(): - import tensorflow as tf - - input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) - attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) - else: raise ValueError("Target return tensor type could not be returned") diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 0e0c59d555f2..08f19924e80b 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -30,7 +30,7 @@ Unpack, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available +from ...utils import TensorType, is_torch_available if TYPE_CHECKING: @@ -115,10 +115,8 @@ def __call__( should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -166,25 +164,11 @@ def __call__( if return_tensors == "np": input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) - - elif return_tensors == "jax" and is_flax_available(): - import jax.numpy as jnp - - input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) - attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) - elif return_tensors == "pt" and is_torch_available(): import torch input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) - - elif return_tensors == "tf" and is_tf_available(): - import tensorflow as tf - - input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) - attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) - else: raise ValueError("Target return tensor type could not be returned") diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 0fc9635cda88..26c3c693245f 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -15,7 +15,6 @@ """PyTorch RemBERT model.""" import math -import os from typing import Optional, Union import torch @@ -46,89 +45,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_rembert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - # Checkpoint is 12Gb, save memory by not loading useless variables - # Output embedding and cls are reset at classification time - if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")): - # logger.info("Skipping loading of %s", name) - continue - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - # Replace prefix with right one - name = name.replace("bert/", "rembert/") - # The pooler is a linear layer - # name = name.replace("pooler/dense", "pooler") - - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info("Skipping {}".format("/".join(name))) - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class RemBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -614,15 +530,12 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: @auto_docstring class RemBertPreTrainedModel(PreTrainedModel): config: RemBertConfig - load_tf_weights = load_tf_weights_in_rembert base_model_prefix = "rembert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1357,5 +1270,4 @@ def forward( "RemBertLayer", "RemBertModel", "RemBertPreTrainedModel", - "load_tf_weights_in_rembert", ] diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 33fb44118a90..4ba12384fbed 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -59,8 +59,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -672,8 +671,6 @@ class RobertaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 81481574b01e..e51214beac7c 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -59,8 +59,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -561,8 +560,6 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 22a72f91bc38..aba89b1c309d 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -15,7 +15,6 @@ """PyTorch RoCBert model.""" import math -import os from typing import Optional, Union import torch @@ -46,80 +45,6 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.bert.modeling_bert.load_tf_weights_in_bert with bert->roc_bert -def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class RoCBertEmbeddings(nn.Module): """Construct the embeddings from word, position, shape, pronunciation and token_type embeddings.""" @@ -725,15 +650,12 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: @auto_docstring class RoCBertPreTrainedModel(PreTrainedModel): config: RoCBertConfig - load_tf_weights = load_tf_weights_in_roc_bert base_model_prefix = "roc_bert" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -1929,5 +1851,4 @@ def forward( "RoCBertLayer", "RoCBertModel", "RoCBertPreTrainedModel", - "load_tf_weights_in_roc_bert", ] diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 3fc94cf87675..c9f59d1a0575 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -15,7 +15,6 @@ """PyTorch RoFormer model.""" import math -import os from typing import Callable, Optional, Union import numpy as np @@ -81,79 +80,6 @@ def forward( return super().forward(position_ids) -def load_tf_weights_in_roformer(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name.replace("bert", "roformer")) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if not pointer.shape == array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class RoFormerEmbeddings(nn.Module): """Construct the embeddings from word and token_type embeddings.""" @@ -768,8 +694,6 @@ class RoFormerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr.py b/src/transformers/models/rt_detr/image_processing_rt_detr.py index de61a8019047..3c0e994c374a 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -50,10 +50,6 @@ ) from ...utils import ( filter_out_non_signature_kwargs, - is_flax_available, - is_jax_tensor, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, logging, @@ -184,18 +180,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -723,10 +711,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -896,10 +882,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index c9b54f561fb6..4879d3655514 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -40,7 +40,6 @@ from ...utils import ( TensorType, filter_out_non_signature_kwargs, - is_tf_available, is_torch_available, is_torchvision_available, logging, @@ -55,12 +54,6 @@ if is_torchvision_available(): from torchvision.ops.boxes import batched_nms -if is_tf_available(): - import tensorflow as tf - from tensorflow.experimental import numpy as tnp - - from ...tf_utils import flatten, shape_list - logger = logging.get_logger(__name__) @@ -456,10 +449,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. @@ -497,18 +488,14 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") if segmentation_maps is not None: segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor." ) validate_preprocess_arguments( do_rescale=do_rescale, @@ -588,12 +575,12 @@ def post_process_masks( Remove padding and upscale masks to the original image size. Args: - masks (`Union[list[torch.Tensor], list[np.ndarray], list[tf.Tensor]]`): + masks (`Union[list[torch.Tensor], list[np.ndarray]]`): Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. - original_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`): + original_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`): The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. - reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`): + reshaped_input_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`): The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. mask_threshold (`float`, *optional*, defaults to 0.0): The threshold to use for binarizing the masks. @@ -603,9 +590,9 @@ def post_process_masks( The target size the images were padded to before being passed to the model. If None, the target size is assumed to be the processor's `pad_size`. return_tensors (`str`, *optional*, defaults to `"pt"`): - If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors. + If `"pt"`, return PyTorch tensors. Returns: - (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. """ if return_tensors == "pt": @@ -617,17 +604,8 @@ def post_process_masks( binarize=binarize, pad_size=pad_size, ) - elif return_tensors == "tf": - return self._post_process_masks_tf( - masks=masks, - original_sizes=original_sizes, - reshaped_input_sizes=reshaped_input_sizes, - mask_threshold=mask_threshold, - binarize=binarize, - pad_size=pad_size, - ) else: - raise ValueError("return_tensors must be either 'pt' or 'tf'") + raise ValueError("return_tensors must be 'pt'") def _post_process_masks_pt( self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None @@ -676,48 +654,6 @@ def _post_process_masks_pt( return output_masks - def _post_process_masks_tf( - self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None - ): - """ - Remove padding and upscale masks to the original image size. - - Args: - masks (`tf.Tensor`): - Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. - original_sizes (`tf.Tensor`): - The original size of the images before resizing for input to the model, in (height, width) format. - reshaped_input_sizes (`tf.Tensor`): - The size of the image input to the model, in (height, width) format. Used to remove padding. - mask_threshold (`float`, *optional*, defaults to 0.0): - The threshold to use for binarizing the masks. - binarize (`bool`, *optional*, defaults to `True`): - Whether to binarize the masks. - pad_size (`int`, *optional*, defaults to `self.pad_size`): - The target size the images were padded to before being passed to the model. If None, the target size is - assumed to be the processor's `pad_size`. - Returns: - (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is - given by original_size. - """ - requires_backends(self, ["tf"]) - pad_size = self.pad_size if pad_size is None else pad_size - target_image_size = (pad_size["height"], pad_size["width"]) - - output_masks = [] - for i, original_size in enumerate(original_sizes): - # tf.image expects NHWC, we transpose the NCHW inputs for it - mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) - interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") - interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :] - interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") - if binarize: - interpolated_mask = interpolated_mask > mask_threshold - # And then we transpose them back at the end - output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2])) - - return output_masks - def post_process_for_mask_generation( self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt" ): @@ -725,21 +661,19 @@ def post_process_for_mask_generation( Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. Args: - all_masks (`Union[list[torch.Tensor], list[tf.Tensor]]`): + all_masks (`list[torch.Tensor]`): List of all predicted segmentation masks - all_scores (`Union[list[torch.Tensor], list[tf.Tensor]]`): + all_scores (`list[torch.Tensor]`): List of all predicted iou scores - all_boxes (`Union[list[torch.Tensor], list[tf.Tensor]]`): + all_boxes (`list[torch.Tensor]`): List of all bounding boxes of the predicted masks crops_nms_thresh (`float`): Threshold for NMS (Non Maximum Suppression) algorithm. return_tensors (`str`, *optional*, defaults to `pt`): - If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + If `pt`, returns `torch.Tensor`. """ if return_tensors == "pt": return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) - elif return_tensors == "tf": - return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) def generate_crop_boxes( self, @@ -776,7 +710,7 @@ def generate_crop_boxes( input_data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. return_tensors (`str`, *optional*, defaults to `pt`): - If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + If `pt`, returns `torch.Tensor`. """ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( image, @@ -795,15 +729,8 @@ def generate_crop_boxes( # cropped_images stays as np input_labels = torch.tensor(input_labels, device=device) - elif return_tensors == "tf": - if device is not None: - raise ValueError("device is not a supported argument when return_tensors is tf!") - crop_boxes = tf.convert_to_tensor(crop_boxes) - points_per_crop = tf.convert_to_tensor(points_per_crop) - # cropped_images stays as np - input_labels = tf.convert_to_tensor(input_labels) else: - raise ValueError("return_tensors must be either 'pt' or 'tf'.") + raise ValueError("return_tensors must be either `'pt'` or `None`") return crop_boxes, points_per_crop, cropped_images, input_labels def filter_masks( @@ -825,9 +752,9 @@ def filter_masks( bounding boxes and pad the predicted masks if necessary. Args: - masks (`Union[torch.Tensor, tf.Tensor]`): + masks (`torch.Tensor`): Input masks. - iou_scores (`Union[torch.Tensor, tf.Tensor]`): + iou_scores (`torch.Tensor`): List of IoU scores. original_size (`tuple[int,int]`): Size of the original image. @@ -842,7 +769,7 @@ def filter_masks( stability_score_offset (`float`, *optional*, defaults to 1): The offset for the stability score used in the `_compute_stability_score` method. return_tensors (`str`, *optional*, defaults to `pt`): - If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + If `pt`, returns `torch.Tensor`. """ if return_tensors == "pt": return self._filter_masks_pt( @@ -855,17 +782,6 @@ def filter_masks( mask_threshold=mask_threshold, stability_score_offset=stability_score_offset, ) - elif return_tensors == "tf": - return self._filter_masks_tf( - masks=masks, - iou_scores=iou_scores, - original_size=original_size, - cropped_box_image=cropped_box_image, - pred_iou_thresh=pred_iou_thresh, - stability_score_thresh=stability_score_thresh, - mask_threshold=mask_threshold, - stability_score_offset=stability_score_offset, - ) def _filter_masks_pt( self, @@ -947,83 +863,6 @@ def _filter_masks_pt( return masks, scores, converted_boxes - def _filter_masks_tf( - self, - masks, - iou_scores, - original_size, - cropped_box_image, - pred_iou_thresh=0.88, - stability_score_thresh=0.95, - mask_threshold=0, - stability_score_offset=1, - ): - """ - Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being - that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability - score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to - bounding boxes and pad the predicted masks if necessary. - - Args: - masks (`tf.Tensor`): - Input masks. - iou_scores (`tf.Tensor`): - List of IoU scores. - original_size (`tuple[int,int]`): - Size of the original image. - cropped_box_image (`np.array`): - The cropped image. - pred_iou_thresh (`float`, *optional*, defaults to 0.88): - The threshold for the iou scores. - stability_score_thresh (`float`, *optional*, defaults to 0.95): - The threshold for the stability score. - mask_threshold (`float`, *optional*, defaults to 0): - The threshold for the predicted masks. - stability_score_offset (`float`, *optional*, defaults to 1): - The offset for the stability score used in the `_compute_stability_score` method. - - """ - requires_backends(self, ["tf"]) - original_height, original_width = original_size - iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) - masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) - - if masks.shape[0] != iou_scores.shape[0]: - raise ValueError("masks and iou_scores must have the same batch size.") - - batch_size = masks.shape[0] - - keep_mask = tf.ones(batch_size, dtype=tf.bool) - - if pred_iou_thresh > 0.0: - keep_mask = keep_mask & (iou_scores > pred_iou_thresh) - - # compute stability score - if stability_score_thresh > 0.0: - stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset) - keep_mask = keep_mask & (stability_scores > stability_score_thresh) - - scores = iou_scores[keep_mask] - masks = masks[keep_mask] - - # binarize masks - masks = masks > mask_threshold - converted_boxes = _batched_mask_to_box_tf(masks) - - keep_mask = ~_is_box_near_crop_edge_tf( - converted_boxes, cropped_box_image, [0, 0, original_width, original_height] - ) - - scores = scores[keep_mask] - masks = masks[keep_mask] - converted_boxes = converted_boxes[keep_mask] - - masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width) - # conversion to rle is necessary to run non-maximum suppression - masks = _mask_to_rle_tf(masks) - - return masks, scores, converted_boxes - def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): # One mask is always contained inside the other. @@ -1036,17 +875,6 @@ def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, st return stability_scores -def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): - # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure - # we get the right division results. - intersections = tf.count_nonzero( - masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32 - ) - unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) - stability_scores = intersections / unions - return stability_scores - - def _build_point_grid(n_per_side: int) -> np.ndarray: """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" offset = 1 / (2 * n_per_side) @@ -1215,16 +1043,6 @@ def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int): return torch.nn.functional.pad(masks, pad, value=0) -def _pad_masks_tf(masks, crop_box: list[int], orig_height: int, orig_width: int): - left, top, right, bottom = crop_box - if left == 0 and top == 0 and right == orig_width and bottom == orig_height: - return masks - # Coordinate transform masks - pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) - pad = (left, pad_x - left, top, pad_y - top) - return tf.pad(masks, pad, constant_values=0) - - def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) @@ -1243,24 +1061,6 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): return torch.any(near_crop_edge, dim=1) -def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0): - """Filter masks at the edge of a crop, but not at the edge of the original image.""" - crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32) - orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32) - - left, top, _, _ = crop_box - offset = tf.convert_to_tensor([[left, top, left, top]]) - # Check if boxes has a channel dimension - if len(boxes.shape) == 3: - offset = tf.expand_dims(offset, 1) - boxes = tf.cast(boxes + offset, tf.float32) - - near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0) - near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0) - near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge) - return tf.reduce_any(near_crop_edge, axis=1) - - def _batched_mask_to_box(masks: "torch.Tensor"): """ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which @@ -1310,54 +1110,6 @@ def _batched_mask_to_box(masks: "torch.Tensor"): return out -def _batched_mask_to_box_tf(masks: "tf.Tensor"): - """ - Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which - corresponds the following required indices: - - LEFT: left hand side of the bounding box - - TOP: top of the bounding box - - RIGHT: right of the bounding box - - BOTTOM: bottom of the bounding box - - Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape - is channel_1 x channel_2 x ... x 4. - - Args: - - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`) - """ - - if tf.size(masks) == 0: - return tf.zeros([*masks.shape[:-2], 4]) - - # Normalize shape to Cxheightxwidth - shape = shape_list(masks) - height, width = shape[-2:] - - # Get top and bottom edges - in_height = tf.reduce_max(masks, axis=-1) - in_height_coords = in_height * tf.range(height)[None, :] - bottom_edges = tf.reduce_max(in_height_coords, axis=-1) - in_height_coords = in_height_coords + height * (~in_height) - top_edges = tf.reduce_min(in_height_coords, axis=-1) - - # Get left and right edges - in_width, _ = tf.reduce_max(masks, axis=-2) - in_width_coords = in_width * tf.range(width)[None, :] - right_edges, _ = tf.reduce_max(in_width_coords, axis=-1) - in_width_coords = in_width_coords + width * (~in_width) - left_edges, _ = tf.reduce_min(in_width_coords, axis=-1) - - # If the mask is empty the right edge will be to the left of the left edge. - # Replace these boxes with [0, 0, 0, 0] - empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) - out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1) - out = out * tf.expand_dims(~empty_filter, -1) - - # Return to original shape - out = tf.reshape(out, *shape[:-2], 4) - return out - - def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): """ Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. @@ -1389,39 +1141,6 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): return out -def _mask_to_rle_tf(input_mask: "tf.Tensor"): - """ - Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. - """ - # Put in fortran order and flatten height and width - batch_size, height, width = input_mask.shape - input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1) - - # Compute change indices - diff = input_mask[:, 1:] ^ input_mask[:, :-1] - change_indices = tf.where(diff) - - # Encode run length - out = [] - for i in range(batch_size): - cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1 - if len(cur_idxs) == 0: - # No changes => either all 0 or all 1 - # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. - if input_mask[i, 0] == 0: - out.append({"size": [height, width], "counts": [height * width]}) - else: - out.append({"size": [height, width], "counts": [0, height * width]}) - continue - btw_idxs = cur_idxs[1:] - cur_idxs[:-1] - counts = [] if input_mask[i, 0] == 0 else [0] - counts += ( - [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()] - ) - out.append({"size": [height, width], "counts": counts}) - return out - - def _rle_to_mask(rle: dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" height, width = rle["size"] @@ -1465,33 +1184,4 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh= return masks, iou_scores, rle_masks, mask_boxes -def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): - """ - Perform NMS (Non Maximum Suppression) on the outputs. - - Args: - rle_masks (`tf.Tensor`): - binary masks in the RLE format - iou_scores (`tf.Tensor` of shape (nb_masks, 1)): - iou_scores predicted by the model - mask_boxes (`tf.Tensor`): - The bounding boxes corresponding to segmentation masks - amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): - NMS threshold. - """ - keep_by_nms = tf.image.combined_non_max_suppression( - boxes=mask_boxes.float(), - scores=iou_scores, - idxs=torch.zeros(mask_boxes.shape[0]), - iou_threshold=amg_crops_nms_thresh, - ) - - iou_scores = iou_scores[keep_by_nms] - rle_masks = [rle_masks[i] for i in keep_by_nms] - mask_boxes = mask_boxes[keep_by_nms] - masks = [_rle_to_mask(rle) for rle in rle_masks] - - return masks, iou_scores, rle_masks, mask_boxes - - __all__ = ["SamImageProcessor"] diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 603adde95040..f7c862d82c40 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -24,16 +24,13 @@ from ...image_utils import ImageInput from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput -from ...utils import is_tf_available, is_torch_available +from ...utils import is_torch_available from ...video_utils import VideoInput if is_torch_available(): import torch -if is_tf_available(): - import tensorflow as tf - class SamImagesKwargs(ImagesKwargs): segmentation_maps: Optional[ImageInput] @@ -102,7 +99,7 @@ def __call__( # pop arguments that are not used in the forward but used nevertheless original_sizes = encoding_image_processor["original_sizes"] - if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor + if hasattr(original_sizes, "numpy"): original_sizes = original_sizes.numpy() input_points, input_labels, input_boxes = self._check_and_preprocess_points( @@ -173,30 +170,18 @@ def _normalize_and_convert( input_boxes = torch.from_numpy(input_boxes) # boxes batch size of 1 by default input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes - elif return_tensors == "tf": - input_boxes = tf.convert_to_tensor(input_boxes) - # boxes batch size of 1 by default - input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes encoding_image_processor.update({"input_boxes": input_boxes}) if input_points is not None: if return_tensors == "pt": input_points = torch.from_numpy(input_points) # point batch size of 1 by default input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points - elif return_tensors == "tf": - input_points = tf.convert_to_tensor(input_points) - # point batch size of 1 by default - input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points encoding_image_processor.update({"input_points": input_points}) if input_labels is not None: if return_tensors == "pt": input_labels = torch.from_numpy(input_labels) # point batch size of 1 by default input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels - elif return_tensors == "tf": - input_labels = tf.convert_to_tensor(input_labels) - # point batch size of 1 by default - input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels encoding_image_processor.update({"input_labels": input_labels}) return encoding_image_processor @@ -250,7 +235,7 @@ def _check_and_preprocess_points( it is converted to a `numpy.ndarray` and then to a `list`. """ if input_points is not None: - if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor + if hasattr(input_points, "numpy"): input_points = input_points.numpy().tolist() if not isinstance(input_points, list) or not isinstance(input_points[0], list): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index f3c6e3fb1a2a..b7629c22933a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -16,7 +16,6 @@ import copy import math -import os import warnings from typing import Optional, Union @@ -64,126 +63,6 @@ logger = logging.get_logger(__name__) -#################################################### -# This dict contains ids and associated url -# for the pretrained weights provided with the models -#################################################### - - -#################################################### -# This is a conversion method from TF 1.0 to PyTorch -# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 -#################################################### -def load_tf_weights_in_t5(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - tf_weights[name] = array - - for txt_name in names: - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - pointer = model - array = tf_weights[txt_name] - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[0] - elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[1] - elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") - pointer = pointer[2] - elif scope_names[0] == "rms_norm": - if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") - elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") - elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": - continue - elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): - pointer = getattr(pointer, f"wi_{scope_names[1]}") - continue - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array.astype(np.float32)) - tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - return model - - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### PARALLELIZE_DOCSTRING = r""" This is an experimental feature and is a subject to change at a moment's notice. diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 075b834533b6..dc2b722789f2 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -16,7 +16,6 @@ import enum import math -import os from dataclasses import dataclass from typing import Optional, Union @@ -66,140 +65,6 @@ class TableQuestionAnsweringOutput(ModelOutput): attentions: Optional[tuple[torch.FloatTensor]] = None -def load_tf_weights_in_tapas(model, config, tf_checkpoint_path): - """ - Load tf checkpoints in a PyTorch model. This is an adaptation from load_tf_weights_in_bert - - - add cell selection and aggregation heads - - take into account additional token type embedding layers - """ - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v - # which are not required for using pretrained model - if any( - n - in [ - "adam_v", - "adam_m", - "AdamWeightDecayOptimizer", - "AdamWeightDecayOptimizer_1", - "global_step", - "seq_relationship", - ] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - # in case the model is TapasForSequenceClassification, we skip output_bias and output_weights - # since these are not used for classification - if isinstance(model, TapasForSequenceClassification): - if any(n in ["output_bias", "output_weights"] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - # in case the model is TapasModel, we skip output_bias, output_weights, output_bias_cls and output_weights_cls - # since this model does not have MLM and NSP heads - if isinstance(model, TapasModel): - if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - # in case the model is TapasForMaskedLM, we skip the pooler - if isinstance(model, TapasForMaskedLM): - if any(n in ["pooler"] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - # if first scope name starts with "bert", change it to "tapas" - if name[0] == "bert": - name[0] = "tapas" - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - # cell selection heads - elif scope_names[0] == "output_bias": - if not isinstance(model, TapasForMaskedLM): - pointer = getattr(pointer, "output_bias") - else: - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "output_weights") - elif scope_names[0] == "column_output_bias": - pointer = getattr(pointer, "column_output_bias") - elif scope_names[0] == "column_output_weights": - pointer = getattr(pointer, "column_output_weights") - # aggregation head - elif scope_names[0] == "output_bias_agg": - pointer = getattr(pointer, "aggregation_classifier") - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights_agg": - pointer = getattr(pointer, "aggregation_classifier") - pointer = getattr(pointer, "weight") - # classification head - elif scope_names[0] == "output_bias_cls": - pointer = getattr(pointer, "classifier") - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights_cls": - pointer = getattr(pointer, "classifier") - pointer = getattr(pointer, "weight") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name[-13:] in [f"_embeddings_{i}" for i in range(7)]: - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - # Added a check to see whether the array is a scalar (because bias terms in Tapas checkpoints can be - # scalar => should first be converted to numpy arrays) - if np.isscalar(array): - array = np.array(array) - pointer.data = torch.from_numpy(array) - return model - - class TapasEmbeddings(nn.Module): """ Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of @@ -699,8 +564,6 @@ class TapasPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -2052,8 +1915,6 @@ def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, # Examples with non-empty cell selection supervision. is_cell_supervision_available = torch.sum(labels, dim=1) > 0 - # torch.where is not equivalent to tf.where (in tensorflow 1) - # hence the added .view on the condition to match the shape of the first tensor aggregate_mask = torch.where( torch.logical_and(is_pred_cell_selection, is_cell_supervision_available).view(aggregate_mask_init.size()), torch.zeros_like(aggregate_mask_init, dtype=torch.float32), @@ -2343,5 +2204,4 @@ def _calculate_regression_loss( "TapasForSequenceClassification", "TapasModel", "TapasPreTrainedModel", - "load_tf_weights_in_tapas", ] diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 7fab73706421..bce68eacd68e 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -30,9 +30,6 @@ PaddingStrategy, TensorType, add_end_docstrings, - is_flax_available, - is_tf_available, - is_torch_available, logging, to_py_obj, ) @@ -42,12 +39,7 @@ if TYPE_CHECKING: - if is_torch_available(): - import torch - if is_tf_available(): - import tensorflow as tf - if is_flax_available(): - import jax.numpy as jnp # noqa: F401 + import torch VOCAB_FILES_NAMES = { @@ -80,7 +72,6 @@ return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. verbose (`bool`, *optional*, defaults to `True`): @@ -457,7 +448,7 @@ def _decode( # because we need docs for `output_char_offsets` here def batch_decode( self, - sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, output_char_offsets: bool = False, @@ -468,7 +459,7 @@ def batch_decode( Convert a list of lists of token ids into a list of strings by calling decode. Args: - sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -527,7 +518,7 @@ def batch_decode( # and `output_word_offsets` here def decode( self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, output_char_offsets: bool = False, @@ -541,7 +532,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. diff --git a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py index 0715e3ce60f2..bd8a89303deb 100644 --- a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +++ b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py @@ -26,9 +26,6 @@ from ...tokenization_utils_base import AddedToken from ...utils import ( ModelOutput, - is_flax_available, - is_tf_available, - is_torch_available, logging, requires_backends, to_py_obj, @@ -39,12 +36,7 @@ if TYPE_CHECKING: - if is_torch_available(): - import torch - if is_tf_available(): - import tensorflow as tf - if is_flax_available(): - import jax.numpy as jnp # noqa: F401 + import torch VOCAB_FILES_NAMES = { @@ -453,7 +445,7 @@ def _decode( # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` here def decode( self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, output_char_offsets: bool = False, @@ -466,7 +458,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -509,7 +501,7 @@ def decode( # we need docs for `output_char_offsets` here def batch_decode( self, - sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, output_char_offsets: bool = False, @@ -519,7 +511,7 @@ def batch_decode( Convert a list of lists of token ids into a list of strings by calling decode. Args: - sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index a73b4a51cea4..fafdd770ce12 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -619,7 +619,6 @@ def ff_chunk(self, input): @auto_docstring class XLMPreTrainedModel(PreTrainedModel): config: XLMConfig - load_tf_weights = None base_model_prefix = "transformer" def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index a3a252572ec9..4c42d7b88615 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -60,8 +60,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -674,8 +673,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index d0c71365d214..658083839b23 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -666,8 +666,6 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 0c6b9f76eade..c5ede2870711 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -36,156 +36,6 @@ logger = logging.get_logger(__name__) -def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None): - """ - A map of modules from TF to PyTorch. I use a map to keep the PyTorch model as identical to the original PyTorch - model as possible. - """ - - tf_to_pt_map = {} - - if hasattr(model, "transformer"): - if hasattr(model, "lm_loss"): - # We will load also the output bias - tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias - if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights: - # We will load also the sequence summary - tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight - tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias - if ( - hasattr(model, "logits_proj") - and config.finetuning_task is not None - and f"model/regression_{config.finetuning_task}/logit/kernel" in tf_weights - ): - tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/kernel"] = model.logits_proj.weight - tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/bias"] = model.logits_proj.bias - - # Now load the rest of the transformer - model = model.transformer - - # Embeddings and output - tf_to_pt_map.update( - { - "model/transformer/word_embedding/lookup_table": model.word_embedding.weight, - "model/transformer/mask_emb/mask_emb": model.mask_emb, - } - ) - - # Transformer blocks - for i, b in enumerate(model.layer): - layer_str = f"model/transformer/layer_{i}/" - tf_to_pt_map.update( - { - layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight, - layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias, - layer_str + "rel_attn/o/kernel": b.rel_attn.o, - layer_str + "rel_attn/q/kernel": b.rel_attn.q, - layer_str + "rel_attn/k/kernel": b.rel_attn.k, - layer_str + "rel_attn/r/kernel": b.rel_attn.r, - layer_str + "rel_attn/v/kernel": b.rel_attn.v, - layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight, - layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias, - layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight, - layer_str + "ff/layer_1/bias": b.ff.layer_1.bias, - layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight, - layer_str + "ff/layer_2/bias": b.ff.layer_2.bias, - } - ) - - # Relative positioning biases - if config.untie_r: - r_r_list = [] - r_w_list = [] - r_s_list = [] - seg_embed_list = [] - for b in model.layer: - r_r_list.append(b.rel_attn.r_r_bias) - r_w_list.append(b.rel_attn.r_w_bias) - r_s_list.append(b.rel_attn.r_s_bias) - seg_embed_list.append(b.rel_attn.seg_embed) - else: - r_r_list = [model.r_r_bias] - r_w_list = [model.r_w_bias] - r_s_list = [model.r_s_bias] - seg_embed_list = [model.seg_embed] - tf_to_pt_map.update( - { - "model/transformer/r_r_bias": r_r_list, - "model/transformer/r_w_bias": r_w_list, - "model/transformer/r_s_bias": r_s_list, - "model/transformer/seg_embed": seg_embed_list, - } - ) - return tf_to_pt_map - - -def load_tf_weights_in_xlnet(model, config, tf_path): - """Load tf checkpoints in a pytorch model""" - try: - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - tf_weights[name] = array - - # Build TF to PyTorch weights loading map - tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights) - - for name, pointer in tf_to_pt_map.items(): - logger.info(f"Importing {name}") - if name not in tf_weights: - logger.info(f"{name} not in tf pre-trained weights, skipping") - continue - array = tf_weights[name] - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name): - logger.info("Transposing") - array = np.transpose(array) - if isinstance(pointer, list): - # Here we will split the TF weights - assert len(pointer) == array.shape[0], ( - f"Pointer length {len(pointer)} and array length {array.shape[0]} mismatched" - ) - for i, p_i in enumerate(pointer): - arr_i = array[i, ...] - try: - assert p_i.shape == arr_i.shape, ( - f"Pointer shape {p_i.shape} and array shape {arr_i.shape} mismatched" - ) - except AssertionError as e: - e.args += (p_i.shape, arr_i.shape) - raise - logger.info(f"Initialize PyTorch weight {name} for layer {i}") - p_i.data = torch.from_numpy(arr_i) - else: - try: - assert pointer.shape == array.shape, ( - f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - tf_weights.pop(name, None) - tf_weights.pop(name + "/Adam", None) - tf_weights.pop(name + "/Adam_1", None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") - return model - - class XLNetRelativeAttention(nn.Module): def __init__(self, config): super().__init__() @@ -797,14 +647,11 @@ def forward( @auto_docstring class XLNetPreTrainedModel(PreTrainedModel): config: XLNetConfig - load_tf_weights = load_tf_weights_in_xlnet base_model_prefix = "transformer" def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -2385,5 +2232,4 @@ def forward( "XLNetLMHeadModel", "XLNetModel", "XLNetPreTrainedModel", - "load_tf_weights_in_xlnet", ] diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 7c8328447cb0..c2f25d9773e9 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -58,8 +58,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -625,8 +624,6 @@ class XmodPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index 48a1300191af..8430b4cc6d67 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -53,11 +53,7 @@ ) from ...utils import ( TensorType, - is_flax_available, - is_jax_tensor, is_scipy_available, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_vision_available, @@ -227,18 +223,10 @@ def get_numpy_to_framework_fn(arr) -> Callable: """ if isinstance(arr, np.ndarray): return np.array - if is_tf_available() and is_tf_tensor(arr): - import tensorflow as tf - - return tf.convert_to_tensor if is_torch_available() and is_torch_tensor(arr): import torch return torch.tensor - if is_flax_available() and is_jax_tensor(arr): - import jax.numpy as jnp - - return jnp.array raise ValueError(f"Cannot convert arrays of type {type(arr)}") @@ -1119,10 +1107,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -1304,10 +1290,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor.") # Here the pad() method pads using the max of (width, height) and does not need to be validated. validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 92da22477f55..7cccf60ca3da 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -58,7 +58,7 @@ PipelineException, PipelineRegistry, get_default_model_and_revision, - infer_framework_load_model, + load_model, ) from .depth_estimation import DepthEstimationPipeline from .document_question_answering import DocumentQuestionAnsweringPipeline @@ -93,23 +93,6 @@ from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import ( - TFAutoModel, - TFAutoModelForCausalLM, - TFAutoModelForImageClassification, - TFAutoModelForMaskedLM, - TFAutoModelForQuestionAnswering, - TFAutoModelForSeq2SeqLM, - TFAutoModelForSequenceClassification, - TFAutoModelForTableQuestionAnswering, - TFAutoModelForTokenClassification, - TFAutoModelForVision2Seq, - TFAutoModelForZeroShotImageClassification, - ) - if is_torch_available(): import torch @@ -144,7 +127,6 @@ if TYPE_CHECKING: - from ..modeling_tf_utils import TFPreTrainedModel from ..modeling_utils import PreTrainedModel from ..tokenization_utils_fast import PreTrainedTokenizerFast @@ -162,290 +144,190 @@ SUPPORTED_TASKS = { "audio-classification": { "impl": AudioClassificationPipeline, - "tf": (), "pt": (AutoModelForAudioClassification,) if is_torch_available() else (), - "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}}, + "default": {"model": ("superb/wav2vec2-base-superb-ks", "372e048")}, "type": "audio", }, "automatic-speech-recognition": { "impl": AutomaticSpeechRecognitionPipeline, - "tf": (), "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (), - "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "22aad52")}}, + "default": {"model": ("facebook/wav2vec2-base-960h", "22aad52")}, "type": "multimodal", }, "text-to-audio": { "impl": TextToAudioPipeline, - "tf": (), "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (), - "default": {"model": {"pt": ("suno/bark-small", "1dbd7a1")}}, + "default": {"model": ("suno/bark-small", "1dbd7a1")}, "type": "text", }, "feature-extraction": { "impl": FeatureExtractionPipeline, - "tf": (TFAutoModel,) if is_tf_available() else (), "pt": (AutoModel,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("distilbert/distilbert-base-cased", "6ea8117"), - "tf": ("distilbert/distilbert-base-cased", "6ea8117"), - } - }, + "default": {"model": ("distilbert/distilbert-base-cased", "6ea8117")}, "type": "multimodal", }, "text-classification": { "impl": TextClassificationPipeline, - "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (), "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"), - "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"), - }, - }, + "default": {"model": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f")}, "type": "text", }, "token-classification": { "impl": TokenClassificationPipeline, - "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (), "pt": (AutoModelForTokenClassification,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"), - "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"), - }, - }, + "default": {"model": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496")}, "type": "text", }, "question-answering": { "impl": QuestionAnsweringPipeline, - "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (), "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"), - "tf": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"), - }, - }, + "default": {"model": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5")}, "type": "text", }, "table-question-answering": { "impl": TableQuestionAnsweringPipeline, "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (), - "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (), - "default": { - "model": { - "pt": ("google/tapas-base-finetuned-wtq", "e3dde19"), - "tf": ("google/tapas-base-finetuned-wtq", "e3dde19"), - }, - }, + "default": {"model": ("google/tapas-base-finetuned-wtq", "e3dde19")}, "type": "text", }, "visual-question-answering": { "impl": VisualQuestionAnsweringPipeline, "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (), - "tf": (), - "default": { - "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "d0a1f6a")}, - }, + "default": {"model": ("dandelin/vilt-b32-finetuned-vqa", "d0a1f6a")}, "type": "multimodal", }, "document-question-answering": { "impl": DocumentQuestionAnsweringPipeline, "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (), - "tf": (), - "default": { - "model": {"pt": ("impira/layoutlm-document-qa", "beed3c4")}, - }, + "default": {"model": ("impira/layoutlm-document-qa", "beed3c4")}, "type": "multimodal", }, "fill-mask": { "impl": FillMaskPipeline, - "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (), "pt": (AutoModelForMaskedLM,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("distilbert/distilroberta-base", "fb53ab8"), - "tf": ("distilbert/distilroberta-base", "fb53ab8"), - } - }, + "default": {"model": ("distilbert/distilroberta-base", "fb53ab8")}, "type": "text", }, "summarization": { "impl": SummarizationPipeline, - "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), - "default": { - "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "df1b051")} - }, + "default": {"model": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e")}, "type": "text", }, # This task is a special case as it's parametrized by SRC, TGT languages. "translation": { "impl": TranslationPipeline, - "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": { - ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}}, - ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}}, - ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}}, + ("en", "fr"): {"model": ("google-t5/t5-base", "a9723ea")}, + ("en", "de"): {"model": ("google-t5/t5-base", "a9723ea")}, + ("en", "ro"): {"model": ("google-t5/t5-base", "a9723ea")}, }, "type": "text", }, "text2text-generation": { "impl": Text2TextGenerationPipeline, - "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), - "default": {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}}, + "default": {"model": ("google-t5/t5-base", "a9723ea")}, "type": "text", }, "text-generation": { "impl": TextGenerationPipeline, - "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (), "pt": (AutoModelForCausalLM,) if is_torch_available() else (), - "default": {"model": {"pt": ("openai-community/gpt2", "607a30d"), "tf": ("openai-community/gpt2", "607a30d")}}, + "default": {"model": ("openai-community/gpt2", "607a30d")}, "type": "text", }, "zero-shot-classification": { "impl": ZeroShotClassificationPipeline, - "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (), "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (), "default": { - "model": { - "pt": ("facebook/bart-large-mnli", "d7645e1"), - "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"), - }, - "config": { - "pt": ("facebook/bart-large-mnli", "d7645e1"), - "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"), - }, + "model": ("facebook/bart-large-mnli", "d7645e1"), + "config": ("facebook/bart-large-mnli", "d7645e1"), }, "type": "text", }, "zero-shot-image-classification": { "impl": ZeroShotImageClassificationPipeline, - "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (), "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("openai/clip-vit-base-patch32", "3d74acf"), - "tf": ("openai/clip-vit-base-patch32", "3d74acf"), - } - }, + "default": {"model": ("openai/clip-vit-base-patch32", "3d74acf")}, "type": "multimodal", }, "zero-shot-audio-classification": { "impl": ZeroShotAudioClassificationPipeline, - "tf": (), "pt": (AutoModel,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("laion/clap-htsat-fused", "cca9e28"), - } - }, + "default": {"model": ("laion/clap-htsat-fused", "cca9e28")}, "type": "multimodal", }, "image-classification": { "impl": ImageClassificationPipeline, - "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (), "pt": (AutoModelForImageClassification,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("google/vit-base-patch16-224", "3f49326"), - "tf": ("google/vit-base-patch16-224", "3f49326"), - } - }, + "default": {"model": ("google/vit-base-patch16-224", "3f49326")}, "type": "image", }, "image-feature-extraction": { "impl": ImageFeatureExtractionPipeline, - "tf": (TFAutoModel,) if is_tf_available() else (), "pt": (AutoModel,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("google/vit-base-patch16-224", "3f49326"), - "tf": ("google/vit-base-patch16-224", "3f49326"), - } - }, + "default": {"model": ("google/vit-base-patch16-224", "3f49326")}, "type": "image", }, "image-segmentation": { "impl": ImageSegmentationPipeline, - "tf": (), "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (), - "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "d53b52a")}}, + "default": {"model": ("facebook/detr-resnet-50-panoptic", "d53b52a")}, "type": "multimodal", }, "image-to-text": { "impl": ImageToTextPipeline, - "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (), "pt": (AutoModelForVision2Seq,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"), - "tf": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"), - } - }, + "default": {"model": ("ydshieh/vit-gpt2-coco-en", "5bebf1e")}, "type": "multimodal", }, "image-text-to-text": { "impl": ImageTextToTextPipeline, - "tf": (), "pt": (AutoModelForImageTextToText,) if is_torch_available() else (), - "default": { - "model": { - "pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"), - } - }, + "default": {"model": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b")}, "type": "multimodal", }, "object-detection": { "impl": ObjectDetectionPipeline, - "tf": (), "pt": (AutoModelForObjectDetection,) if is_torch_available() else (), - "default": {"model": {"pt": ("facebook/detr-resnet-50", "1d5f47b")}}, + "default": {"model": ("facebook/detr-resnet-50", "1d5f47b")}, "type": "multimodal", }, "zero-shot-object-detection": { "impl": ZeroShotObjectDetectionPipeline, - "tf": (), "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (), - "default": {"model": {"pt": ("google/owlvit-base-patch32", "cbc355f")}}, + "default": {"model": ("google/owlvit-base-patch32", "cbc355f")}, "type": "multimodal", }, "depth-estimation": { "impl": DepthEstimationPipeline, - "tf": (), "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (), - "default": {"model": {"pt": ("Intel/dpt-large", "bc15f29")}}, + "default": {"model": ("Intel/dpt-large", "bc15f29")}, "type": "image", }, "video-classification": { "impl": VideoClassificationPipeline, - "tf": (), "pt": (AutoModelForVideoClassification,) if is_torch_available() else (), - "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "488eb9a")}}, + "default": {"model": ("MCG-NJU/videomae-base-finetuned-kinetics", "488eb9a")}, "type": "video", }, "mask-generation": { "impl": MaskGenerationPipeline, - "tf": (), "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (), - "default": {"model": {"pt": ("facebook/sam-vit-huge", "87aecf0")}}, + "default": {"model": ("facebook/sam-vit-huge", "87aecf0")}, "type": "multimodal", }, "image-to-image": { "impl": ImageToImagePipeline, - "tf": (), "pt": (AutoModelForImageToImage,) if is_torch_available() else (), - "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}}, + "default": {"model": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}, "type": "image", }, "keypoint-matching": { "impl": KeypointMatchingPipeline, - "tf": (), "pt": (AutoModelForKeypointMatching,) if is_torch_available() else (), - "default": {"model": {"pt": ("magic-leap-community/superglue_outdoor", "f4041f8")}}, + "default": {"model": ("magic-leap-community/superglue_outdoor", "f4041f8")}, "type": "image", }, } @@ -545,10 +427,6 @@ def clean_custom_task(task_info): if isinstance(pt_class_names, str): pt_class_names = [pt_class_names] task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names) - tf_class_names = task_info.get("tf", ()) - if isinstance(tf_class_names, str): - tf_class_names = [tf_class_names] - task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names) return task_info, None @@ -565,67 +443,67 @@ def clean_custom_task(task_info): @overload -def pipeline(task: Literal[None], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> Pipeline: ... +def pipeline(task: Literal[None], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> Pipeline: ... @overload -def pipeline(task: Literal["audio-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> AudioClassificationPipeline: ... +def pipeline(task: Literal["audio-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> AudioClassificationPipeline: ... @overload -def pipeline(task: Literal["automatic-speech-recognition"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> AutomaticSpeechRecognitionPipeline: ... +def pipeline(task: Literal["automatic-speech-recognition"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> AutomaticSpeechRecognitionPipeline: ... @overload -def pipeline(task: Literal["depth-estimation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> DepthEstimationPipeline: ... +def pipeline(task: Literal["depth-estimation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> DepthEstimationPipeline: ... @overload -def pipeline(task: Literal["document-question-answering"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> DocumentQuestionAnsweringPipeline: ... +def pipeline(task: Literal["document-question-answering"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> DocumentQuestionAnsweringPipeline: ... @overload -def pipeline(task: Literal["feature-extraction"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> FeatureExtractionPipeline: ... +def pipeline(task: Literal["feature-extraction"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> FeatureExtractionPipeline: ... @overload -def pipeline(task: Literal["fill-mask"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> FillMaskPipeline: ... +def pipeline(task: Literal["fill-mask"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> FillMaskPipeline: ... @overload -def pipeline(task: Literal["image-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageClassificationPipeline: ... +def pipeline(task: Literal["image-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageClassificationPipeline: ... @overload -def pipeline(task: Literal["image-feature-extraction"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageFeatureExtractionPipeline: ... +def pipeline(task: Literal["image-feature-extraction"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageFeatureExtractionPipeline: ... @overload -def pipeline(task: Literal["image-segmentation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageSegmentationPipeline: ... +def pipeline(task: Literal["image-segmentation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageSegmentationPipeline: ... @overload -def pipeline(task: Literal["image-text-to-text"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageTextToTextPipeline: ... +def pipeline(task: Literal["image-text-to-text"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageTextToTextPipeline: ... @overload -def pipeline(task: Literal["image-to-image"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToImagePipeline: ... +def pipeline(task: Literal["image-to-image"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToImagePipeline: ... @overload -def pipeline(task: Literal["image-to-text"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToTextPipeline: ... +def pipeline(task: Literal["image-to-text"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ImageToTextPipeline: ... @overload -def pipeline(task: Literal["keypoint-matching"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> KeypointMatchingPipeline: ... +def pipeline(task: Literal["keypoint-matching"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> KeypointMatchingPipeline: ... @overload -def pipeline(task: Literal["mask-generation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> MaskGenerationPipeline: ... +def pipeline(task: Literal["mask-generation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> MaskGenerationPipeline: ... @overload -def pipeline(task: Literal["object-detection"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ObjectDetectionPipeline: ... +def pipeline(task: Literal["object-detection"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ObjectDetectionPipeline: ... @overload -def pipeline(task: Literal["question-answering"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> QuestionAnsweringPipeline: ... +def pipeline(task: Literal["question-answering"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> QuestionAnsweringPipeline: ... @overload -def pipeline(task: Literal["summarization"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> SummarizationPipeline: ... +def pipeline(task: Literal["summarization"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> SummarizationPipeline: ... @overload -def pipeline(task: Literal["table-question-answering"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TableQuestionAnsweringPipeline: ... +def pipeline(task: Literal["table-question-answering"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TableQuestionAnsweringPipeline: ... @overload -def pipeline(task: Literal["text-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextClassificationPipeline: ... +def pipeline(task: Literal["text-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextClassificationPipeline: ... @overload -def pipeline(task: Literal["text-generation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextGenerationPipeline: ... +def pipeline(task: Literal["text-generation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextGenerationPipeline: ... @overload -def pipeline(task: Literal["text-to-audio"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextToAudioPipeline: ... +def pipeline(task: Literal["text-to-audio"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TextToAudioPipeline: ... @overload -def pipeline(task: Literal["text2text-generation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> Text2TextGenerationPipeline: ... +def pipeline(task: Literal["text2text-generation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> Text2TextGenerationPipeline: ... @overload -def pipeline(task: Literal["token-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TokenClassificationPipeline: ... +def pipeline(task: Literal["token-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TokenClassificationPipeline: ... @overload -def pipeline(task: Literal["translation"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TranslationPipeline: ... +def pipeline(task: Literal["translation"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> TranslationPipeline: ... @overload -def pipeline(task: Literal["video-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> VideoClassificationPipeline: ... +def pipeline(task: Literal["video-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> VideoClassificationPipeline: ... @overload -def pipeline(task: Literal["visual-question-answering"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> VisualQuestionAnsweringPipeline: ... +def pipeline(task: Literal["visual-question-answering"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> VisualQuestionAnsweringPipeline: ... @overload -def pipeline(task: Literal["zero-shot-audio-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotAudioClassificationPipeline: ... +def pipeline(task: Literal["zero-shot-audio-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotAudioClassificationPipeline: ... @overload -def pipeline(task: Literal["zero-shot-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotClassificationPipeline: ... +def pipeline(task: Literal["zero-shot-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotClassificationPipeline: ... @overload -def pipeline(task: Literal["zero-shot-image-classification"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotImageClassificationPipeline: ... +def pipeline(task: Literal["zero-shot-image-classification"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotImageClassificationPipeline: ... @overload -def pipeline(task: Literal["zero-shot-object-detection"], model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotObjectDetectionPipeline: ... +def pipeline(task: Literal["zero-shot-object-detection"], model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, device: Optional[Union[int, str, "torch.device"]] = None, device_map: Optional[Union[str, dict[str, Union[int, str]]]] = None, dtype: Optional[Union[str, "torch.dtype"]] = "auto", trust_remote_code: Optional[bool] = None, model_kwargs: Optional[dict[str, Any]] = None, pipeline_class: Optional[Any] = None, **kwargs: Any) -> ZeroShotObjectDetectionPipeline: ... # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # The part of the file above was automatically generated from the code. @@ -636,13 +514,12 @@ def pipeline(task: Literal["zero-shot-object-detection"], model: Optional[Union[ def pipeline( task: Optional[str] = None, - model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None, + model: Optional[Union[str, "PreTrainedModel"]] = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, processor: Optional[Union[str, ProcessorMixin]] = None, - framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, token: Optional[Union[str, bool]] = None, @@ -759,13 +636,6 @@ def pipeline( If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model` is not specified or not a string, then the default processor for `config` is loaded (if it is a string). - framework (`str`, *optional*): - The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be - installed. - - If no framework is specified, will default to the one currently installed. If no framework is specified and - both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is - provided. revision (`str`, *optional*, defaults to `"main"`): When passing a task name or a string model identifier: The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other @@ -969,8 +839,7 @@ def pipeline( # Use default model/config/tokenizer for the task if no model is provided if model is None: - # At that point framework might still be undetermined - model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options) + model, default_revision = get_default_model_and_revision(targeted_task, task_options) revision = revision if revision is not None else default_revision logger.warning( f"No model was supplied, defaulted to {model} and revision" @@ -1022,14 +891,12 @@ def pipeline( model_name = model if isinstance(model, str) else None # Load the correct model if possible - # Infer the framework from the model if not already defined - if isinstance(model, str) or framework is None: - model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} - framework, model = infer_framework_load_model( + if isinstance(model, str): + model_classes = targeted_task["pt"] + model = load_model( adapter_path if adapter_path is not None else model, model_classes=model_classes, config=config, - framework=framework, task=task, **hub_kwargs, **model_kwargs, @@ -1227,4 +1094,4 @@ def pipeline( if processor is not None: kwargs["processor"] = processor - return pipeline_class(model=model, framework=framework, task=task, **kwargs) + return pipeline_class(model=model, task=task, **kwargs) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 944c7a90a184..09eaab16922f 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -42,8 +42,6 @@ PushToHubMixin, add_end_docstrings, copy_func, - infer_framework, - is_tf_available, is_torch_available, is_torch_cuda_available, is_torch_hpu_available, @@ -54,15 +52,9 @@ is_torch_xpu_available, logging, ) -from ..utils.deprecation import deprecate_kwarg -GenericTensor = Union[list["GenericTensor"], "torch.Tensor", "tf.Tensor"] - -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import TFAutoModel +GenericTensor = Union[list["GenericTensor"], "torch.Tensor"] if is_torch_available(): import torch @@ -78,7 +70,6 @@ KeyDataset = None if TYPE_CHECKING: - from ..modeling_tf_utils import TFPreTrainedModel from ..modeling_utils import PreTrainedModel @@ -207,30 +198,27 @@ def inner(items): return inner -def infer_framework_load_model( +def load_model( model, config: AutoConfig, - model_classes: Optional[dict[str, tuple[type]]] = None, + model_classes: Optional[tuple[type]] = None, task: Optional[str] = None, - framework: Optional[str] = None, **model_kwargs, ): """ - Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model). + Load a model. - If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is + If `model` is instantiated, this function will just return it. Otherwise `model` is actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to instantiate the model twice, this model is returned for use by the pipeline. - If both frameworks are installed and available for `model`, PyTorch is selected. - Args: - model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`): - The model to infer the framework from. If `str`, a checkpoint name. The model to infer the framewrok from. + model (`str`, or [`PreTrainedModel`]): + If `str`, a checkpoint name. The model to load. config ([`AutoConfig`]): The config associated with the model to help using the correct class - model_classes (dictionary `str` to `type`, *optional*): - A mapping framework to class. + model_classes (`tuple[type]`, *optional*): + A tuple of model classes. task (`str`): The task defining which pipeline will be returned. model_kwargs: @@ -238,36 +226,20 @@ def infer_framework_load_model( **model_kwargs)` function. Returns: - `Tuple`: A tuple framework, model. + The model. """ - if not is_tf_available() and not is_torch_available(): - raise RuntimeError( - "At least one of TensorFlow 2.0 or PyTorch should be installed. " - "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " - "To install PyTorch, read the instructions at https://pytorch.org/." - ) + if not is_torch_available(): + raise RuntimeError("PyTorch should be installed. Please follow the instructions at https://pytorch.org/.") if isinstance(model, str): model_kwargs["_from_pipeline"] = task - class_tuple = () - look_pt = is_torch_available() and framework in {"pt", None} - look_tf = is_tf_available() and framework in {"tf", None} - if model_classes: - if look_pt: - class_tuple = class_tuple + model_classes.get("pt", (AutoModel,)) - if look_tf: - class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,)) + class_tuple = model_classes if model_classes is not None else (AutoModel,) if config.architectures: classes = [] for architecture in config.architectures: transformers_module = importlib.import_module("transformers") - if look_pt: - _class = getattr(transformers_module, architecture, None) - if _class is not None: - classes.append(_class) - if look_tf: - _class = getattr(transformers_module, f"TF{architecture}", None) - if _class is not None: - classes.append(_class) + _class = getattr(transformers_module, architecture, None) + if _class is not None: + classes.append(_class) class_tuple = class_tuple + tuple(classes) if len(class_tuple) == 0: @@ -276,23 +248,9 @@ def infer_framework_load_model( all_traceback = {} for model_class in class_tuple: kwargs = model_kwargs.copy() - if framework == "pt" and model.endswith(".h5"): - kwargs["from_tf"] = True - logger.warning( - "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. " - "Trying to load the model with PyTorch." - ) - elif framework == "tf" and model.endswith(".bin"): - kwargs["from_pt"] = True - logger.warning( - "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. " - "Trying to load the model with Tensorflow." - ) try: model = model_class.from_pretrained(model, **kwargs) - if hasattr(model, "eval"): - model = model.eval() # Stop loading on the first successful load. break except (OSError, ValueError, TypeError, RuntimeError): @@ -300,7 +258,7 @@ def infer_framework_load_model( # is not supported on the execution device (e.g. bf16 on a consumer GPU). We capture those so # we can transparently retry the load in float32 before surfacing an error to the user. fallback_tried = False - if is_torch_available() and ("dtype" in kwargs): + if "dtype" in kwargs: import torch # local import to avoid unnecessarily importing torch for TF/JAX users fallback_tried = True @@ -309,8 +267,6 @@ def infer_framework_load_model( try: model = model_class.from_pretrained(model, **fp32_kwargs) - if hasattr(model, "eval"): - model = model.eval() logger.warning( "Falling back to torch.float32 because loading with the original dtype failed on the" " target device." @@ -334,97 +290,17 @@ def infer_framework_load_model( f"Could not load model {model} with any of the following classes: {class_tuple}. See the original errors:\n\n{error}\n" ) - if framework is None: - framework = infer_framework(model.__class__) - return framework, model - - -def infer_framework_from_model( - model, - model_classes: Optional[dict[str, tuple[type]]] = None, - task: Optional[str] = None, - framework: Optional[str] = None, - **model_kwargs, -): - """ - Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model). - - If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is - actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to - instantiate the model twice, this model is returned for use by the pipeline. - - If both frameworks are installed and available for `model`, PyTorch is selected. - - Args: - model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`): - The model to infer the framework from. If `str`, a checkpoint name. The model to infer the framewrok from. - model_classes (dictionary `str` to `type`, *optional*): - A mapping framework to class. - task (`str`): - The task defining which pipeline will be returned. - model_kwargs: - Additional dictionary of keyword arguments passed along to the model's `from_pretrained(..., - **model_kwargs)` function. - - Returns: - `Tuple`: A tuple framework, model. - """ - if isinstance(model, str): - config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs) - else: - config = model.config - return infer_framework_load_model( - model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs - ) - + return model -def get_framework(model, revision: Optional[str] = None): - """ - Select framework (TensorFlow or PyTorch) to use. - Args: - model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`): - If both frameworks are installed, picks the one corresponding to the model passed (either a model class or - the model name). If no specific model is provided, defaults to using PyTorch. +def get_default_model_and_revision(targeted_task: dict, task_options: Optional[Any]) -> tuple[str, str]: """ - warnings.warn( - "`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.", - FutureWarning, - ) - if not is_tf_available() and not is_torch_available(): - raise RuntimeError( - "At least one of TensorFlow 2.0 or PyTorch should be installed. " - "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " - "To install PyTorch, read the instructions at https://pytorch.org/." - ) - if isinstance(model, str): - if is_torch_available() and not is_tf_available(): - model = AutoModel.from_pretrained(model, revision=revision) - elif is_tf_available() and not is_torch_available(): - model = TFAutoModel.from_pretrained(model, revision=revision) - else: - try: - model = AutoModel.from_pretrained(model, revision=revision) - except OSError: - model = TFAutoModel.from_pretrained(model, revision=revision) - - framework = infer_framework(model.__class__) - return framework - - -def get_default_model_and_revision( - targeted_task: dict, framework: Optional[str], task_options: Optional[Any] -) -> tuple[str, str]: - """ - Select a default model to use for a given task. Defaults to pytorch if ambiguous. + Select a default model to use for a given task. Args: targeted_task (`Dict`): Dictionary representing the given task, that should contain default models - framework (`str`, None) - "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet. - task_options (`Any`, None) Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for translation task. @@ -435,11 +311,6 @@ def get_default_model_and_revision( - `str` The model string representing the default model for this pipeline. - `str` The revision of the model. """ - if is_torch_available() and not is_tf_available(): - framework = "pt" - elif is_tf_available() and not is_torch_available(): - framework = "tf" - defaults = targeted_task["default"] if task_options: if task_options not in defaults: @@ -452,10 +323,7 @@ def get_default_model_and_revision( # parametrized raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_xx_to_yy"') - if framework is None: - framework = "pt" - - return default_models[framework] + return default_models def load_assistant_model( @@ -480,16 +348,10 @@ def load_assistant_model( if not model.can_generate() or assistant_model is None: return None, None - if getattr(model, "framework") != "pt" or not isinstance(model, PreTrainedModel): - raise ValueError( - "Assisted generation, triggered by the `assistant_model` argument, is only available for " - "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported." - ) - # If the model is passed as a string, load the model and the corresponding tokenizer if isinstance(assistant_model, str): assistant_config = AutoConfig.from_pretrained(assistant_model) - _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config) + loaded_assistant_model = load_model(assistant_model, config=assistant_config) loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype) loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model) else: @@ -838,13 +700,6 @@ def build_pipeline_init_args( docstring += r""" modelcard (`str` or [`ModelCard`], *optional*): Model card attributed to the model for this pipeline. - framework (`str`, *optional*): - The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be - installed. - - If no framework is specified, will default to the one currently installed. If no framework is specified and - both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is - provided. task (`str`, defaults to `""`): A task-identifier for the pipeline. num_workers (`int`, *optional*, defaults to 8): @@ -943,13 +798,12 @@ class Pipeline(_ScikitCompat, PushToHubMixin): def __init__( self, - model: Union["PreTrainedModel", "TFPreTrainedModel"], + model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizer] = None, feature_extractor: Optional[PreTrainedFeatureExtractor] = None, image_processor: Optional[BaseImageProcessor] = None, processor: Optional[ProcessorMixin] = None, modelcard: Optional[ModelCard] = None, - framework: Optional[str] = None, task: str = "", device: Optional[Union[int, "torch.device"]] = None, binary_output: bool = False, @@ -957,13 +811,6 @@ def __init__( ): # We need to pop them for _sanitize_parameters call later _, _, _ = kwargs.pop("args_parser", None), kwargs.pop("torch_dtype", None), kwargs.pop("dtype", None) - if framework is None: - framework, model = infer_framework_load_model(model, config=model.config) - if framework in ("tf", "jax"): - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) self.task = task self.model = model @@ -972,7 +819,6 @@ def __init__( self.image_processor = image_processor self.processor = processor self.modelcard = modelcard - self.framework = framework # `accelerate` device map hf_device_map = getattr(self.model, "hf_device_map", None) @@ -990,45 +836,42 @@ def __init__( else: device = 0 - if is_torch_available() and self.framework == "pt": - if device == -1 and self.model.device is not None: - device = self.model.device - if isinstance(device, torch.device): - if (device.type == "xpu" and not is_torch_xpu_available(check_device=True)) or ( - device.type == "hpu" and not is_torch_hpu_available() - ): - raise ValueError(f'{device} is not available, you should use device="cpu" instead') + if device == -1 and self.model.device is not None: + device = self.model.device + if isinstance(device, torch.device): + if (device.type == "xpu" and not is_torch_xpu_available(check_device=True)) or ( + device.type == "hpu" and not is_torch_hpu_available() + ): + raise ValueError(f'{device} is not available, you should use device="cpu" instead') - self.device = device - elif isinstance(device, str): - if ("xpu" in device and not is_torch_xpu_available(check_device=True)) or ( - "hpu" in device and not is_torch_hpu_available() - ): - raise ValueError(f'{device} is not available, you should use device="cpu" instead') - - self.device = torch.device(device) - elif device < 0: - self.device = torch.device("cpu") - elif is_torch_mlu_available(): - self.device = torch.device(f"mlu:{device}") - elif is_torch_musa_available(): - self.device = torch.device(f"musa:{device}") - elif is_torch_cuda_available(): - self.device = torch.device(f"cuda:{device}") - elif is_torch_npu_available(): - self.device = torch.device(f"npu:{device}") - elif is_torch_hpu_available(): - self.device = torch.device(f"hpu:{device}") - elif is_torch_xpu_available(check_device=True): - self.device = torch.device(f"xpu:{device}") - elif is_torch_mps_available(): - self.device = torch.device(f"mps:{device}") - else: - self.device = torch.device("cpu") + self.device = device + elif isinstance(device, str): + if ("xpu" in device and not is_torch_xpu_available(check_device=True)) or ( + "hpu" in device and not is_torch_hpu_available() + ): + raise ValueError(f'{device} is not available, you should use device="cpu" instead') + + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") + elif is_torch_mlu_available(): + self.device = torch.device(f"mlu:{device}") + elif is_torch_musa_available(): + self.device = torch.device(f"musa:{device}") + elif is_torch_cuda_available(): + self.device = torch.device(f"cuda:{device}") + elif is_torch_npu_available(): + self.device = torch.device(f"npu:{device}") + elif is_torch_hpu_available(): + self.device = torch.device(f"hpu:{device}") + elif is_torch_xpu_available(check_device=True): + self.device = torch.device(f"xpu:{device}") + elif is_torch_mps_available(): + self.device = torch.device(f"mps:{device}") else: - self.device = device if device is not None else -1 + self.device = torch.device("cpu") - if is_torch_available() and torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): self.device = self.model.device logger.warning(f"Device set to use {self.device}") @@ -1036,8 +879,7 @@ def __init__( # We shouldn't call `model.to()` for models loaded with accelerate as well as the case that model is already on device if ( - self.framework == "pt" - and self.model.device != self.device + self.model.device != self.device and not (isinstance(self.device, int) and self.device < 0) and hf_device_map is None ): @@ -1167,7 +1009,6 @@ def save_pretrained( # Change classes into their names/full names info["impl"] = f"{last_module}.{info['impl'].__name__}" info["pt"] = tuple(c.__name__ for c in info["pt"]) - info["tf"] = tuple(c.__name__ for c in info["tf"]) custom_pipelines[task] = info self.model.config.custom_pipelines = custom_pipelines @@ -1219,7 +1060,7 @@ def torch_dtype(self) -> Optional["torch.dtype"]: @contextmanager def device_placement(self): """ - Context Manager allowing tensor allocation on the user-specified device in framework agnostic way. + Context Manager allowing tensor allocation on the user-specified device. Returns: Context manager @@ -1230,27 +1071,23 @@ def device_placement(self): # Explicitly ask for tensor allocation on CUDA device :0 pipe = pipeline(..., device=0) with pipe.device_placement(): - # Every framework specific tensor allocation will be done on the request device + # Every tensor allocation will be done on the request device output = pipe(...) ```""" - if self.framework == "tf": - with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"): + if self.device.type == "cuda": + with torch.cuda.device(self.device): yield - else: - if self.device.type == "cuda": - with torch.cuda.device(self.device): - yield - elif self.device.type == "mlu": - with torch.mlu.device(self.device): - yield - elif self.device.type == "musa": - with torch.musa.device(self.device): - yield - elif self.device.type == "xpu": - with torch.xpu.device(self.device): - yield - else: + elif self.device.type == "mlu": + with torch.mlu.device(self.device): + yield + elif self.device.type == "musa": + with torch.musa.device(self.device): + yield + elif self.device.type == "xpu": + with torch.xpu.device(self.device): yield + else: + yield def ensure_tensor_on_device(self, **inputs): """ @@ -1364,17 +1201,11 @@ def get_inference_context(self): def forward(self, model_inputs, **forward_params): with self.device_placement(): - if self.framework == "tf": - model_inputs["training"] = False + inference_context = self.get_inference_context() + with inference_context(): + model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) model_outputs = self._forward(model_inputs, **forward_params) - elif self.framework == "pt": - inference_context = self.get_inference_context() - with inference_context(): - model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) - model_outputs = self._forward(model_inputs, **forward_params) - model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu")) - else: - raise ValueError(f"Framework {self.framework} is not supported") + model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu")) return model_outputs def get_iterator( @@ -1425,7 +1256,7 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): postprocess_params = {**self._postprocess_params, **postprocess_params} self.call_count += 1 - if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda": + if self.call_count > 10 and self.device.type == "cuda": logger.warning_once( "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" " dataset", @@ -1436,9 +1267,7 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): is_list = isinstance(inputs, list) is_iterable = is_dataset or is_generator or is_list - - # TODO make the get_iterator work also for `tf` (and `flax`). - can_use_iterator = self.framework == "pt" and (is_dataset or is_generator or is_list) + can_use_iterator = is_dataset or is_generator or is_list if is_list: if can_use_iterator: @@ -1455,7 +1284,7 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): ) elif is_iterable: return self.iterate(inputs, preprocess_params, forward_params, postprocess_params) - elif self.framework == "pt" and isinstance(self, ChunkPipeline): + elif isinstance(self, ChunkPipeline): return next( iter( self.get_iterator( @@ -1550,13 +1379,11 @@ def check_task(self, task: str) -> tuple[str, dict, Any]: f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}" ) - @deprecate_kwarg(old_name="tf_model", version="5.0.0") def register_pipeline( self, task: str, pipeline_class: type, pt_model: Optional[Union[type, tuple[type]]] = None, - tf_model: Optional[Union[type, tuple[type]]] = None, default: Optional[dict] = None, type: Optional[str] = None, ) -> None: @@ -1568,15 +1395,10 @@ def register_pipeline( elif not isinstance(pt_model, tuple): pt_model = (pt_model,) - if tf_model is None: - tf_model = () - elif not isinstance(tf_model, tuple): - tf_model = (tf_model,) - - task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model} + task_impl = {"impl": pipeline_class, "pt": pt_model} if default is not None: - if "model" not in default and ("pt" in default or "tf" in default): + if "model" not in default: default = {"model": default} task_impl["default"] = default diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index cc69cf6d2792..49a45b5a7f5c 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -2,16 +2,10 @@ import numpy as np -from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging +from ..utils import add_end_docstrings, is_torch_available, logging from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args -if is_tf_available(): - import tensorflow as tf - - from ..tf_utils import stable_softmax - - if is_torch_available(): import torch @@ -90,12 +84,7 @@ class FillMaskPipeline(Pipeline): """ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray: - if self.framework == "tf": - masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy() - elif self.framework == "pt": - masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False) - else: - raise ValueError("Unsupported framework") + masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False) return masked_index def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray: @@ -140,29 +129,15 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None): input_ids = model_outputs["input_ids"][0] outputs = model_outputs["logits"] - if self.framework == "tf": - masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0] - - outputs = outputs.numpy() - - logits = outputs[0, masked_index, :] - probs = stable_softmax(logits, axis=-1) - if target_ids is not None: - probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1)) - probs = tf.expand_dims(probs, 0) - - topk = tf.math.top_k(probs, k=top_k) - values, predictions = topk.values.numpy(), topk.indices.numpy() - else: - masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1) - # Fill mask pipeline supports only one ${mask_token} per sample + masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1) + # Fill mask pipeline supports only one ${mask_token} per sample - logits = outputs[0, masked_index, :] - probs = logits.softmax(dim=-1) - if target_ids is not None: - probs = probs[..., target_ids] + logits = outputs[0, masked_index, :] + probs = logits.softmax(dim=-1) + if target_ids is not None: + probs = probs[..., target_ids] - values, predictions = probs.topk(top_k) + values, predictions = probs.topk(top_k) result = [] single_mask = values.shape[0] == 1 diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index ee86074a4c58..5eeeb51cf389 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -12,7 +12,6 @@ from ..utils import ( PaddingStrategy, add_end_docstrings, - is_tf_available, is_tokenizers_available, is_torch_available, logging, @@ -29,12 +28,6 @@ if is_tokenizers_available(): import tokenizers -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES - - Dataset = None if is_torch_available(): import torch @@ -268,7 +261,6 @@ def __init__( model: Union["PreTrainedModel", "TFPreTrainedModel"], tokenizer: PreTrainedTokenizer, modelcard: Optional[ModelCard] = None, - framework: Optional[str] = None, task: str = "", **kwargs, ): @@ -276,17 +268,12 @@ def __init__( model=model, tokenizer=tokenizer, modelcard=modelcard, - framework=framework, task=task, **kwargs, ) self._args_parser = QuestionAnsweringArgumentHandler() - self.check_model_type( - TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES) @staticmethod def create_sample( @@ -503,16 +490,10 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio for k, v in feature.__dict__.items(): if k in model_input_names: - if self.framework == "tf": - tensor = tf.constant(v) - if tensor.dtype == tf.int64: - tensor = tf.cast(tensor, tf.int32) - fw_args[k] = tf.expand_dims(tensor, 0) - elif self.framework == "pt": - tensor = torch.tensor(v) - if tensor.dtype == torch.int32: - tensor = tensor.long() - fw_args[k] = tensor.unsqueeze(0) + tensor = torch.tensor(v) + if tensor.dtype == torch.int32: + tensor = tensor.long() + fw_args[k] = tensor.unsqueeze(0) else: others[k] = v @@ -523,7 +504,7 @@ def _forward(self, inputs): example = inputs["example"] model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names} # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported - model_forward = self.model.forward if self.framework == "pt" else self.model.call + model_forward = self.model.forward if "use_cache" in inspect.signature(model_forward).parameters: model_inputs["use_cache"] = False output = self.model(**model_inputs) @@ -544,7 +525,7 @@ def postprocess( min_null_score = 1000000 # large and positive answers = [] for output in model_outputs: - if self.framework == "pt" and output["start"].dtype == torch.bfloat16: + if output["start"].dtype == torch.bfloat16: start_ = output["start"].to(torch.float32) end_ = output["end"].to(torch.float32) else: diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index da579423d2d4..04190b552910 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -6,7 +6,6 @@ from ..generation import GenerationConfig from ..utils import ( add_end_docstrings, - is_tf_available, is_torch_available, requires_backends, ) @@ -21,14 +20,6 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, ) -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import ( - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, - TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, - ) - class TableQuestionAnsweringArgumentHandler(ArgumentHandler): """ @@ -135,12 +126,8 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, * super().__init__(*args, **kwargs) self._args_parser = args_parser - if self.framework == "tf": - mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy() - mapping.update(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES) - else: - mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy() - mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES) + mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy() + mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES) self.check_model_type(mapping) self.aggregate = getattr(self.model.config, "aggregation_labels", None) and getattr( @@ -156,129 +143,67 @@ def sequential_inference(self, **inputs): Inference used for models that need to process sequences in a sequential fashion, like the SQA models which handle conversational query related to a table. """ - if self.framework == "pt": - all_logits = [] - all_aggregations = [] - prev_answers = None - batch_size = inputs["input_ids"].shape[0] - - input_ids = inputs["input_ids"].to(self.device) - attention_mask = inputs["attention_mask"].to(self.device) - token_type_ids = inputs["token_type_ids"].to(self.device) - token_type_ids_example = None - - for index in range(batch_size): - # If sequences have already been processed, the token type IDs will be created according to the previous - # answer. - if prev_answers is not None: - prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,) - model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,) - - token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - for i in range(model_labels.shape[0]): - segment_id = token_type_ids_example[:, 0].tolist()[i] - col_id = token_type_ids_example[:, 1].tolist()[i] - 1 - row_id = token_type_ids_example[:, 2].tolist()[i] - 1 - - if row_id >= 0 and col_id >= 0 and segment_id == 1: - model_labels[i] = int(prev_answers[(col_id, row_id)]) - - token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device) - - input_ids_example = input_ids[index] - attention_mask_example = attention_mask[index] # shape (seq_len,) - token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - outputs = self.model( - input_ids=input_ids_example.unsqueeze(0), - attention_mask=attention_mask_example.unsqueeze(0), - token_type_ids=token_type_ids_example.unsqueeze(0), - ) - logits = outputs.logits - - if self.aggregate: - all_aggregations.append(outputs.logits_aggregation) + all_logits = [] + all_aggregations = [] + prev_answers = None + batch_size = inputs["input_ids"].shape[0] + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + token_type_ids = inputs["token_type_ids"].to(self.device) + token_type_ids_example = None + + for index in range(batch_size): + # If sequences have already been processed, the token type IDs will be created according to the previous + # answer. + if prev_answers is not None: + prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,) + model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,) - all_logits.append(logits) - - dist_per_token = torch.distributions.Bernoulli(logits=logits) - probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to( - dist_per_token.probs.device - ) - - coords_to_probs = collections.defaultdict(list) - for i, p in enumerate(probabilities.squeeze().tolist()): + token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) + for i in range(model_labels.shape[0]): segment_id = token_type_ids_example[:, 0].tolist()[i] - col = token_type_ids_example[:, 1].tolist()[i] - 1 - row = token_type_ids_example[:, 2].tolist()[i] - 1 - if col >= 0 and row >= 0 and segment_id == 1: - coords_to_probs[(col, row)].append(p) + col_id = token_type_ids_example[:, 1].tolist()[i] - 1 + row_id = token_type_ids_example[:, 2].tolist()[i] - 1 - prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs} + if row_id >= 0 and col_id >= 0 and segment_id == 1: + model_labels[i] = int(prev_answers[(col_id, row_id)]) - logits_batch = torch.cat(tuple(all_logits), 0) + token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device) - return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0)) - else: - all_logits = [] - all_aggregations = [] - prev_answers = None - batch_size = inputs["input_ids"].shape[0] - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - token_type_ids = inputs["token_type_ids"].numpy() - token_type_ids_example = None - - for index in range(batch_size): - # If sequences have already been processed, the token type IDs will be created according to the previous - # answer. - if prev_answers is not None: - prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,) - model_labels = np.zeros_like(prev_labels_example, dtype=np.int32) # shape (seq_len,) - - token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - for i in range(model_labels.shape[0]): - segment_id = token_type_ids_example[:, 0].tolist()[i] - col_id = token_type_ids_example[:, 1].tolist()[i] - 1 - row_id = token_type_ids_example[:, 2].tolist()[i] - 1 - - if row_id >= 0 and col_id >= 0 and segment_id == 1: - model_labels[i] = int(prev_answers[(col_id, row_id)]) - - token_type_ids_example[:, 3] = model_labels - - input_ids_example = input_ids[index] - attention_mask_example = attention_mask[index] # shape (seq_len,) - token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - outputs = self.model( - input_ids=np.expand_dims(input_ids_example, axis=0), - attention_mask=np.expand_dims(attention_mask_example, axis=0), - token_type_ids=np.expand_dims(token_type_ids_example, axis=0), - ) - logits = outputs.logits + input_ids_example = input_ids[index] + attention_mask_example = attention_mask[index] # shape (seq_len,) + token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) + outputs = self.model( + input_ids=input_ids_example.unsqueeze(0), + attention_mask=attention_mask_example.unsqueeze(0), + token_type_ids=token_type_ids_example.unsqueeze(0), + ) + logits = outputs.logits - if self.aggregate: - all_aggregations.append(outputs.logits_aggregation) + if self.aggregate: + all_aggregations.append(outputs.logits_aggregation) - all_logits.append(logits) + all_logits.append(logits) - probabilities = tf.math.sigmoid(tf.cast(logits, tf.float32)) * tf.cast( - attention_mask_example, tf.float32 - ) + dist_per_token = torch.distributions.Bernoulli(logits=logits) + probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to( + dist_per_token.probs.device + ) - coords_to_probs = collections.defaultdict(list) - for i, p in enumerate(tf.squeeze(probabilities).numpy().tolist()): - segment_id = token_type_ids_example[:, 0].tolist()[i] - col = token_type_ids_example[:, 1].tolist()[i] - 1 - row = token_type_ids_example[:, 2].tolist()[i] - 1 - if col >= 0 and row >= 0 and segment_id == 1: - coords_to_probs[(col, row)].append(p) + coords_to_probs = collections.defaultdict(list) + for i, p in enumerate(probabilities.squeeze().tolist()): + segment_id = token_type_ids_example[:, 0].tolist()[i] + col = token_type_ids_example[:, 1].tolist()[i] - 1 + row = token_type_ids_example[:, 2].tolist()[i] - 1 + if col >= 0 and row >= 0 and segment_id == 1: + coords_to_probs[(col, row)].append(p) - prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs} + prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs} - logits_batch = tf.concat(tuple(all_logits), 0) + logits_batch = torch.cat(tuple(all_logits), 0) - return (logits_batch,) if not self.aggregate else (logits_batch, tf.concat(tuple(all_aggregations), 0)) + return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0)) def __call__(self, *args, **kwargs): r""" @@ -393,7 +318,7 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=N raise ValueError("table is empty") if query is None or query == "": raise ValueError("query is empty") - inputs = self.tokenizer(table, query, return_tensors=self.framework, truncation=truncation, padding=padding) + inputs = self.tokenizer(table, query, return_tensors="pt", truncation=truncation, padding=padding) inputs["table"] = table return inputs diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 8952b5820867..c9da04d37154 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -4,15 +4,10 @@ from ..generation import GenerationConfig from ..tokenization_utils import TruncationStrategy -from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging +from ..utils import add_end_docstrings, is_torch_available, logging from .base import Pipeline, build_pipeline_init_args -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES - if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES @@ -84,11 +79,7 @@ class Text2TextGenerationPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.check_model_type( - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES) def _sanitize_parameters( self, @@ -153,7 +144,7 @@ def _parse_and_tokenize(self, *args, truncation): raise TypeError( f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`" ) - inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework) + inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors="pt") # This is produced by tokenizers but is an invalid generate kwargs if "token_type_ids" in inputs: del inputs["token_type_ids"] @@ -184,7 +175,7 @@ def __call__(self, *args: Union[str, list[str]], **kwargs: Any) -> list[dict[str A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: - **generated_text** (`str`, present when `return_text=True`) -- The generated text. - - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token + - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. """ @@ -202,10 +193,7 @@ def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kw return inputs def _forward(self, model_inputs, **generate_kwargs): - if self.framework == "pt": - in_b, input_length = model_inputs["input_ids"].shape - elif self.framework == "tf": - in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy() + in_b, input_length = model_inputs["input_ids"].shape self.check_inputs( input_length, @@ -219,10 +207,7 @@ def _forward(self, model_inputs, **generate_kwargs): output_ids = self.model.generate(**model_inputs, **generate_kwargs) out_b = output_ids.shape[0] - if self.framework == "pt": - output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:]) - elif self.framework == "tf": - output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:])) + output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:]) return {"output_ids": output_ids} def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False): @@ -264,13 +249,10 @@ class SummarizationPipeline(Text2TextGenerationPipeline): Usage: ```python - # use bart in pytorch + # use bart summarizer = pipeline("summarization") summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) - # use t5 in tf - summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf") - summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) ```""" # Used in the return key of the pipeline. @@ -297,7 +279,7 @@ def __call__(self, *args, **kwargs): A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: - **summary_text** (`str`, present when `return_text=True`) -- The summary of the corresponding input. - - **summary_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token + - **summary_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the summary. """ return super().__call__(*args, **kwargs) @@ -356,7 +338,7 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int): def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None): if getattr(self.tokenizer, "_build_translation_inputs", None): return self.tokenizer._build_translation_inputs( - *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang + *args, return_tensors="pt", truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang ) else: return super()._parse_and_tokenize(*args, truncation=truncation) @@ -404,7 +386,7 @@ def __call__(self, *args, **kwargs): A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: - **translation_text** (`str`, present when `return_text=True`) -- The translation. - - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The + - **translation_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the translation. """ return super().__call__(*args, **kwargs) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 7d703ba50117..c77ca1d4bd37 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -4,7 +4,7 @@ from typing import Any, overload from ..generation import GenerationConfig -from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available +from ..utils import ModelOutput, add_end_docstrings, is_torch_available from .base import Pipeline, build_pipeline_init_args @@ -14,11 +14,6 @@ from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from .pt_utils import KeyDataset -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ChatType = list[dict[str, str]] @@ -119,9 +114,7 @@ class TextGenerationPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.check_model_type( - TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) if "prefix" not in self._preprocess_params: # This is very specific. The logic is quite complex and needs to be done # as a "default". @@ -181,7 +174,7 @@ def _sanitize_parameters( preprocess_params["prefix"] = prefix if prefix: prefix_inputs = self.tokenizer( - prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework + prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors="pt" ) generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1] @@ -298,14 +291,14 @@ def __call__(self, text_inputs, **kwargs): a chat, it is passed to `apply_chat_template`. Otherwise, it is passed to `__call__`. generate_kwargs (`dict`, *optional*): Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./text_generation)). + [here](./text_generation)). Return: A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination of both `generated_text` and `generated_token_ids`): - **generated_text** (`str`, present when `return_text=True`) -- The generated text. - - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token + - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. """ if isinstance( @@ -365,11 +358,11 @@ def preprocess( add_generation_prompt=not continue_final_message, continue_final_message=continue_final_message, return_dict=True, - return_tensors=self.framework, + return_tensors="pt", **tokenizer_kwargs, ) else: - inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs) + inputs = self.tokenizer(prefix + prompt_text, return_tensors="pt", **tokenizer_kwargs) inputs["prompt_text"] = prompt_text @@ -436,29 +429,18 @@ def _forward(self, model_inputs, **generate_kwargs): other_outputs = {k: v for k, v in output.items() if k not in {"sequences", "past_key_values"}} out_b = generated_sequence.shape[0] - if self.framework == "pt": - for key, value in other_outputs.items(): - if isinstance(value, torch.Tensor) and value.shape[0] == out_b: - other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:]) - if isinstance(value, tuple) and len(value[0]) == out_b: - value = torch.stack(value).swapaxes(0, 1) - other_outputs[key] = value - elif self.framework == "tf": - for key, value in other_outputs.items(): - if isinstance(value, tf.Tensor) and value.shape[0] == out_b: - other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:])) - if isinstance(value, tuple) and len(value[0]) == out_b: - value = tf.stack(value).swapaxes(0, 1) - other_outputs[key] = value + for key, value in other_outputs.items(): + if isinstance(value, torch.Tensor) and value.shape[0] == out_b: + other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:]) + if isinstance(value, tuple) and len(value[0]) == out_b: + value = torch.stack(value).swapaxes(0, 1) + other_outputs[key] = value else: generated_sequence = output other_outputs = {} out_b = generated_sequence.shape[0] - if self.framework == "pt": - generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) - elif self.framework == "tf": - generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) + generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) model_outputs = { "generated_sequence": generated_sequence, @@ -485,14 +467,9 @@ def postprocess( other_outputs = model_outputs.get("additional_outputs", {}) split_keys = {} if other_outputs: - if self.framework == "pt": - for k, v in other_outputs.items(): - if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence): - split_keys[k] = v.numpy().tolist() - elif self.framework == "tf": - for k, v in other_outputs.items(): - if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence): - split_keys[k] = v.numpy().tolist() + for k, v in other_outputs.items(): + if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence): + split_keys[k] = v.numpy().tolist() skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True for idx, sequence in enumerate(generated_sequence): diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index efa70ca1851f..0df615edcfd3 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -8,16 +8,11 @@ from ..utils import ( ExplicitEnum, add_end_docstrings, - is_tf_available, is_torch_available, ) from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args -if is_tf_available(): - import tensorflow as tf - - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): import torch @@ -144,11 +139,7 @@ class TokenClassificationPipeline(ChunkPipeline): def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs): super().__init__(*args, **kwargs) - self.check_model_type( - TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES) self._basic_tokenizer = BasicTokenizer(do_lower_case=False) self._args_parser = args_parser @@ -308,7 +299,7 @@ def preprocess(self, sentence, offset_mapping=None, **preprocess_params): inputs = self.tokenizer( text_to_tokenize, - return_tensors=self.framework, + return_tensors="pt", truncation=truncation, return_special_tokens_mask=True, return_offsets_mapping=self.tokenizer.is_fast, @@ -322,10 +313,7 @@ def preprocess(self, sentence, offset_mapping=None, **preprocess_params): num_chunks = len(inputs["input_ids"]) for i in range(num_chunks): - if self.framework == "tf": - model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()} - else: - model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()} + model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()} if offset_mapping is not None: model_inputs["offset_mapping"] = offset_mapping @@ -346,11 +334,8 @@ def _forward(self, model_inputs): word_ids = model_inputs.pop("word_ids", None) word_to_chars_map = model_inputs.pop("word_to_chars_map", None) - if self.framework == "tf": - logits = self.model(**model_inputs)[0] - else: - output = self.model(**model_inputs) - logits = output["logits"] if isinstance(output, dict) else output[0] + output = self.model(**model_inputs) + logits = output["logits"] if isinstance(output, dict) else output[0] return { "logits": logits, @@ -372,7 +357,7 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE word_to_chars_map = all_outputs[0].get("word_to_chars_map") for model_outputs in all_outputs: - if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16): + if model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16): logits = model_outputs["logits"][0].to(torch.float32).numpy() else: logits = model_outputs["logits"][0].numpy() @@ -389,10 +374,6 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE shifted_exp = np.exp(logits - maxes) scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True) - if self.framework == "tf": - input_ids = input_ids.numpy() - offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None - pre_entities = self.gather_pre_entities( sentence, input_ids, @@ -470,9 +451,8 @@ def gather_pre_entities( end_ind += start_char if not isinstance(start_ind, int): - if self.framework == "pt": - start_ind = start_ind.item() - end_ind = end_ind.item() + start_ind = start_ind.item() + end_ind = end_ind.item() word_ref = sentence[start_ind:end_ind] if getattr(self.tokenizer, "_tokenizer", None) and getattr( self.tokenizer._tokenizer.model, "continuing_subword_prefix", None diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index d8ec62124556..57b7b118f27a 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -91,7 +91,6 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_3_available, - is_flax_available, is_flute_available, is_fp_quant_available, is_fsdp_available, @@ -107,7 +106,6 @@ is_ipex_available, is_jinja_available, is_jumanpp_available, - is_keras_nlp_available, is_kernels_available, is_levenshtein_available, is_librosa_available, @@ -143,7 +141,6 @@ is_spqr_available, is_sudachi_available, is_sudachi_projection_available, - is_tf_available, is_tiktoken_available, is_timm_available, is_tokenizers_available, @@ -680,18 +677,6 @@ def require_torchcodec(test_case): return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case) -def require_torch_or_tf(test_case): - """ - Decorator marking a test that requires PyTorch or TensorFlow. - - These tests are skipped when neither PyTorch not TensorFlow is installed. - - """ - return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")( - test_case - ) - - def require_intel_extension_for_pytorch(test_case): """ Decorator marking a test that requires Intel Extension for PyTorch. @@ -749,13 +734,6 @@ def require_tokenizers(test_case): return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) -def require_keras_nlp(test_case): - """ - Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed. - """ - return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case) - - def require_pandas(test_case): """ Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. @@ -1023,16 +1001,6 @@ def require_torch_multi_hpu(test_case): else: torch_device = None -if is_tf_available(): - import tensorflow as tf - -if is_flax_available(): - import jax - - jax_device = jax.default_backend() -else: - jax_device = None - def require_torchdynamo(test_case): """Decorator marking a test that requires TorchDynamo""" @@ -1545,20 +1513,12 @@ def require_mistral_common(test_case): def get_gpu_count(): """ - Return the number of available gpus (regardless of whether torch, tf or jax is used) + Return the number of available gpus """ if is_torch_available(): import torch return torch.cuda.device_count() - elif is_tf_available(): - import tensorflow as tf - - return len(tf.config.list_physical_devices("GPU")) - elif is_flax_available(): - import jax - - return jax.device_count() else: return 0 @@ -2581,8 +2541,6 @@ def nested_simplify(obj, decimals=3): return obj elif is_torch_available() and isinstance(obj, torch.Tensor): return nested_simplify(obj.tolist(), decimals) - elif is_tf_available() and tf.is_tensor(obj): - return nested_simplify(obj.numpy().tolist()) elif isinstance(obj, float): return round(obj, decimals) elif isinstance(obj, (np.int32, np.float32, np.float16)): diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f88711fdb655..36a99d66e23d 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -47,15 +47,11 @@ copy_func, download_url, extract_commit_hash, - is_flax_available, - is_jax_tensor, is_mlx_available, is_numpy_array, is_offline_mode, is_protobuf_available, is_remote_url, - is_tf_available, - is_tf_tensor, is_tokenizers_available, is_torch_available, is_torch_device, @@ -72,10 +68,6 @@ if TYPE_CHECKING: if is_torch_available(): import torch - if is_tf_available(): - import tensorflow as tf - if is_flax_available(): - import jax.numpy as jnp # noqa: F401 def import_protobuf_decode_error(error_message=""): @@ -214,13 +206,13 @@ class BatchEncoding(UserDict): space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this information. tensor_type (`Union[None, str, TensorType]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at initialization. prepend_batch_axis (`bool`, *optional*, defaults to `False`): Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*. n_sequences (`Optional[int]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at initialization. """ @@ -714,22 +706,7 @@ def convert_to_tensors( if not isinstance(tensor_type, TensorType): tensor_type = TensorType(tensor_type) - # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW: - if not is_tf_available(): - raise ImportError( - "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." - ) - import tensorflow as tf - - def as_tensor(value, dtype=None): - if len(flatten(value)) == 0 and dtype is None: - dtype = tf.int32 - return tf.constant(value, dtype=dtype) - - is_tensor = tf.is_tensor - - elif tensor_type == TensorType.PYTORCH: + if tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") import torch @@ -743,18 +720,6 @@ def as_tensor(value, dtype=None): is_tensor = torch.is_tensor - elif tensor_type == TensorType.JAX: - if not is_flax_available(): - raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") - import jax.numpy as jnp # noqa: F811 - - def as_tensor(value, dtype=None): - if len(flatten(value)) == 0 and dtype is None: - dtype = jnp.int32 - return jnp.array(value, dtype=dtype) - - is_tensor = is_jax_tensor - elif tensor_type == TensorType.MLX: if not is_mlx_available(): raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.") @@ -1269,7 +1234,6 @@ def _set_model_specific_special_tokens(self, special_tokens: list[str]): return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ @@ -1614,10 +1578,8 @@ def apply_chat_template( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.Tensor` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. return_dict (`bool`, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. @@ -2700,7 +2662,7 @@ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bo """, """ Returns: - `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. + `list[int]`, `torch.Tensor`, or `np.ndarray`: The tokenized ids of the text. """, ) def encode( @@ -2924,11 +2886,6 @@ def __call__( "verbose": verbose, } - if return_tensors in ("tf", "jax"): - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) all_kwargs.update(kwargs) if text is None and text_target is None: raise ValueError("You need to specify either `text` or `text_target`.") @@ -3308,7 +3265,7 @@ def pad( - If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + If the `encoded_inputs` passed are dictionary of numpy arrays, or PyTorch tensors, the result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the specific device of your tensors however. @@ -3321,7 +3278,7 @@ def pad( list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader collate function. - Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see + Instead of `list[int]` you can have tensors (numpy arrays, or PyTorch tensors), see the note above for the return type. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding @@ -3351,7 +3308,6 @@ def pad( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. verbose (`bool`, *optional*, defaults to `True`): @@ -3385,7 +3341,7 @@ def pad( encoded_inputs["attention_mask"] = [] return encoded_inputs - # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # If we have PyTorch/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch @@ -3398,16 +3354,14 @@ def pad( break # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_tensor(first_element): - return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_tensor(first_element): + if is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." + "Should be one of a python, numpy, or pytorch object." ) for key, value in encoded_inputs.items(): @@ -3861,7 +3815,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: def batch_decode( self, - sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, **kwargs, @@ -3870,7 +3824,7 @@ def batch_decode( Convert a list of lists of token ids into a list of strings by calling decode. Args: - sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -3895,7 +3849,7 @@ def batch_decode( def decode( self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, **kwargs, @@ -3907,7 +3861,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -4105,7 +4059,6 @@ def prepare_seq2seq_batch( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`): diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index e2a382db6c91..a2d84b024057 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -31,7 +31,6 @@ from .utils import ( ExplicitEnum, is_psutil_available, - is_tf_available, is_torch_available, is_torch_cuda_available, is_torch_hpu_available, @@ -61,8 +60,7 @@ def seed_worker(worker_id: int, num_workers: int, rank: int): def enable_full_determinism(seed: int, warn_only: bool = False): """ Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow + https://pytorch.org/docs/stable/notes/randomness.html for pytorch """ # set seed first set_seed(seed) @@ -84,15 +82,10 @@ def enable_full_determinism(seed: int, warn_only: bool = False): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - if is_tf_available(): - import tensorflow as tf - - tf.config.experimental.enable_op_determinism() - def set_seed(seed: int, deterministic: bool = False): """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed). + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` (if installed). Args: seed (`int`): @@ -118,12 +111,6 @@ def set_seed(seed: int, deterministic: bool = False): torch.hpu.manual_seed_all(seed) if is_torch_xpu_available(): torch.xpu.manual_seed_all(seed) - if is_tf_available(): - import tensorflow as tf - - tf.random.set_seed(seed) - if deterministic: - tf.config.experimental.enable_op_determinism() def neftune_post_forward_hook(module, input, output): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 1b4671a55e8c..b87223ad99d0 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -57,12 +57,8 @@ filter_out_non_signature_kwargs, find_labels, flatten_dict, - infer_framework, - is_jax_tensor, is_numpy_array, is_tensor, - is_tf_symbolic_tensor, - is_tf_tensor, is_timm_config_dict, is_timm_local_checkpoint, is_torch_device, diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 6488c6d16bdd..f9a787a74a13 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -95,15 +95,6 @@ def docstring_decorator(fn): """ -TF_RETURN_INTRODUCTION = r""" - Returns: - [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if - `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the - configuration ([`{config_class}`]) and inputs. - -""" - - def _get_indent(t): """Returns the indentation in the first line of t""" search = re.search(r"^(\s*)\S", t) @@ -160,8 +151,7 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_i # Add the return introduction if add_intro: full_output_type = f"{output_type.__module__}.{output_type.__name__}" - intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION - intro = intro.format(full_output_type=full_output_type, config_class=config_class) + intro = PT_RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class) else: full_output_type = str(output_type) intro = f"\nReturns:\n `{full_output_type}`" @@ -999,445 +989,6 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_i ) -TF_TOKEN_CLASSIFICATION_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer( - ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf" - ... ) - - >>> logits = model(**inputs).logits - >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1) - - >>> # Note that tokens are classified rather then input words which means that - >>> # there might be more predicted token classes than words. - >>> # Multiple token classes might account for the same word - >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()] - >>> predicted_tokens_classes - {expected_output} - ``` - - ```python - >>> labels = predicted_token_class_ids - >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss) - >>> round(float(loss), 2) - {expected_loss} - ``` -""" - -TF_QUESTION_ANSWERING_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - - >>> inputs = tokenizer(question, text, return_tensors="tf") - >>> outputs = model(**inputs) - - >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0]) - >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0]) - - >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] - >>> tokenizer.decode(predict_answer_tokens) - {expected_output} - ``` - - ```python - >>> # target is "nice puppet" - >>> target_start_index = tf.constant([{qa_target_start_index}]) - >>> target_end_index = tf.constant([{qa_target_end_index}]) - - >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) - >>> loss = tf.math.reduce_mean(outputs.loss) - >>> round(float(loss), 2) - {expected_loss} - ``` -""" - -TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") - - >>> logits = model(**inputs).logits - - >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0]) - >>> model.config.id2label[predicted_class_id] - {expected_output} - ``` - - ```python - >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` - >>> num_labels = len(model.config.id2label) - >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) - - >>> labels = tf.constant(1) - >>> loss = model(**inputs, labels=labels).loss - >>> round(float(loss), 2) - {expected_loss} - ``` -""" - -TF_MASKED_LM_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf") - >>> logits = model(**inputs).logits - - >>> # retrieve index of {mask} - >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0]) - >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index) - - >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1) - >>> tokenizer.decode(predicted_token_id) - {expected_output} - ``` - - ```python - >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] - >>> # mask labels of non-{mask} tokens - >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) - - >>> outputs = model(**inputs, labels=labels) - >>> round(float(outputs.loss), 2) - {expected_loss} - ``` -""" - -TF_BASE_MODEL_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") - >>> outputs = model(inputs) - - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -TF_MULTIPLE_CHOICE_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> choice0 = "It is eaten with a fork and a knife." - >>> choice1 = "It is eaten while held in the hand." - - >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True) - >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} - >>> outputs = model(inputs) # batch size is 1 - - >>> # the linear classifier still needs to be trained - >>> logits = outputs.logits - ``` -""" - -TF_CAUSAL_LM_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - >>> import tensorflow as tf - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") - >>> outputs = model(inputs) - >>> logits = outputs.logits - ``` -""" - -TF_SPEECH_BASE_MODEL_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoProcessor, {model_class} - >>> from datasets import load_dataset - - >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") - >>> dataset = dataset.sort("id") - >>> sampling_rate = dataset.features["audio"].sampling_rate - - >>> processor = AutoProcessor.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> # audio file is decoded on the fly - >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") - >>> outputs = model(**inputs) - - >>> last_hidden_states = outputs.last_hidden_state - >>> list(last_hidden_states.shape) - {expected_output} - ``` -""" - -TF_SPEECH_CTC_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoProcessor, {model_class} - >>> from datasets import load_dataset - >>> import tensorflow as tf - - >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") - >>> dataset = dataset.sort("id") - >>> sampling_rate = dataset.features["audio"].sampling_rate - - >>> processor = AutoProcessor.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> # audio file is decoded on the fly - >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") - >>> logits = model(**inputs).logits - >>> predicted_ids = tf.math.argmax(logits, axis=-1) - - >>> # transcribe speech - >>> transcription = processor.batch_decode(predicted_ids) - >>> transcription[0] - {expected_output} - ``` - - ```python - >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids - - >>> # compute loss - >>> loss = model(**inputs).loss - >>> round(float(loss), 2) - {expected_loss} - ``` -""" - -TF_VISION_BASE_MODEL_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoImageProcessor, {model_class} - >>> from datasets import load_dataset - - >>> dataset = load_dataset("huggingface/cats-image") - >>> image = dataset["test"]["image"][0] - - >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = image_processor(image, return_tensors="tf") - >>> outputs = model(**inputs) - - >>> last_hidden_states = outputs.last_hidden_state - >>> list(last_hidden_states.shape) - {expected_output} - ``` -""" - -TF_VISION_SEQ_CLASS_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoImageProcessor, {model_class} - >>> import tensorflow as tf - >>> from datasets import load_dataset - - >>> dataset = load_dataset("huggingface/cats-image")) - >>> image = dataset["test"]["image"][0] - - >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = image_processor(image, return_tensors="tf") - >>> logits = model(**inputs).logits - - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_label = int(tf.math.argmax(logits, axis=-1)) - >>> print(model.config.id2label[predicted_label]) - {expected_output} - ``` -""" - -TF_SAMPLE_DOCSTRINGS = { - "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE, - "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE, - "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE, - "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE, - "MaskedLM": TF_MASKED_LM_SAMPLE, - "LMHead": TF_CAUSAL_LM_SAMPLE, - "BaseModel": TF_BASE_MODEL_SAMPLE, - "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE, - "CTC": TF_SPEECH_CTC_SAMPLE, - "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE, - "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE, -} - - -FLAX_TOKEN_CLASSIFICATION_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") - - >>> outputs = model(**inputs) - >>> logits = outputs.logits - ``` -""" - -FLAX_QUESTION_ANSWERING_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - >>> inputs = tokenizer(question, text, return_tensors="jax") - - >>> outputs = model(**inputs) - >>> start_scores = outputs.start_logits - >>> end_scores = outputs.end_logits - ``` -""" - -FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") - - >>> outputs = model(**inputs) - >>> logits = outputs.logits - ``` -""" - -FLAX_MASKED_LM_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax") - - >>> outputs = model(**inputs) - >>> logits = outputs.logits - ``` -""" - -FLAX_BASE_MODEL_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") - >>> outputs = model(**inputs) - - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - -FLAX_MULTIPLE_CHOICE_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> choice0 = "It is eaten with a fork and a knife." - >>> choice1 = "It is eaten while held in the hand." - - >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True) - >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}}) - - >>> logits = outputs.logits - ``` -""" - -FLAX_CAUSAL_LM_SAMPLE = r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, {model_class} - - >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") - >>> outputs = model(**inputs) - - >>> # retrieve logts for next token - >>> next_token_logits = outputs.logits[:, -1] - ``` -""" - -FLAX_SAMPLE_DOCSTRINGS = { - "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE, - "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE, - "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE, - "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE, - "MaskedLM": FLAX_MASKED_LM_SAMPLE, - "BaseModel": FLAX_BASE_MODEL_SAMPLE, - "LMHead": FLAX_CAUSAL_LM_SAMPLE, -} - - def filter_outputs_from_example(docstring, **kwargs): """ Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`. @@ -1472,12 +1023,7 @@ def docstring_decorator(fn): # model_class defaults to function's class if not specified otherwise model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls - if model_class[:2] == "TF": - sample_docstrings = TF_SAMPLE_DOCSTRINGS - elif model_class[:4] == "Flax": - sample_docstrings = FLAX_SAMPLE_DOCSTRINGS - else: - sample_docstrings = PT_SAMPLE_DOCSTRINGS + sample_docstrings = PT_SAMPLE_DOCSTRINGS # putting all kwargs for docstrings in a dict to be used # with the `.format(**doc_kwargs)`. Note that string might diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 94d842eee826..cc2f8b5f7046 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -32,9 +32,7 @@ from ..utils import logging from .import_utils import ( - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, is_torch_fx_proxy, requires, @@ -76,10 +74,6 @@ def infer_framework_from_repr(x): representation = str(type(x)) if representation.startswith(" Date: Tue, 9 Sep 2025 16:50:44 +0200 Subject: [PATCH 04/35] more and more --- src/transformers/__init__.py | 24 --- src/transformers/data/data_collator.py | 8 - .../feature_extraction_sequence_utils.py | 9 +- src/transformers/file_utils.py | 4 - src/transformers/generation/__init__.py | 3 +- .../integrations/integration_utils.py | 9 - .../image_processing_conditional_detr.py | 2 - .../models/data2vec/modeling_data2vec_text.py | 2 +- .../modeling_decision_transformer.py | 1 - .../models/luke/tokenization_luke.py | 9 +- .../models/mluke/tokenization_mluke.py | 9 +- src/transformers/models/mt5/modeling_mt5.py | 1 - .../models/openai/modeling_openai.py | 82 -------- .../models/roformer/modeling_roformer.py | 5 +- src/transformers/models/t5/modeling_t5.py | 2 - src/transformers/onnx/convert.py | 122 ++---------- src/transformers/onnx/features.py | 138 ++----------- src/transformers/pipelines/__init__.py | 1 - .../pipelines/image_classification.py | 17 +- src/transformers/pipelines/image_to_text.py | 34 +--- .../pipelines/text_classification.py | 22 +-- .../zero_shot_image_classification.py | 26 +-- src/transformers/utils/__init__.py | 9 - src/transformers/utils/hub.py | 6 +- src/transformers/utils/import_utils.py | 185 +---------------- tests/models/byt5/test_tokenization_byt5.py | 20 +- .../test_tokenization_layoutlmv2.py | 23 +-- .../test_tokenization_layoutlmv3.py | 23 +-- .../layoutxlm/test_tokenization_layoutxlm.py | 23 +-- .../models/marian/test_tokenization_marian.py | 13 +- .../markuplm/test_tokenization_markuplm.py | 23 +-- tests/models/myt5/test_tokenization_myt5.py | 9 - .../pegasus/test_tokenization_pegasus.py | 19 -- .../perceiver/test_tokenization_perceiver.py | 20 +- tests/models/sam2/test_processor_sam2.py | 5 +- .../sam2_video/test_processor_sam2_video.py | 5 +- .../models/siglip/test_tokenization_siglip.py | 22 +-- .../splinter/test_tokenization_splinter.py | 12 +- tests/models/t5/test_tokenization_t5.py | 21 +- tests/models/udop/test_tokenization_udop.py | 23 +-- tests/pipelines/test_pipelines_common.py | 29 +-- .../test_pipelines_mask_generation.py | 9 - utils/check_repo.py | 186 +++--------------- utils/create_dummy_models.py | 100 +--------- 44 files changed, 149 insertions(+), 1166 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d52faa02c86c..b4df2baa235e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -40,12 +40,9 @@ # so that mypy, pylint or other static linters can recognize them, # given that they are not exported using `__all__` in this file. from .utils import is_bitsandbytes_available as is_bitsandbytes_available -from .utils import is_flax_available as is_flax_available -from .utils import is_keras_nlp_available as is_keras_nlp_available from .utils import is_scipy_available as is_scipy_available from .utils import is_sentencepiece_available as is_sentencepiece_available from .utils import is_speech_available as is_speech_available -from .utils import is_tensorflow_text_available as is_tensorflow_text_available from .utils import is_timm_available as is_timm_available from .utils import is_tokenizers_available as is_tokenizers_available from .utils import is_torch_available as is_torch_available @@ -224,8 +221,6 @@ "is_bitsandbytes_available", "is_datasets_available", "is_faiss_available", - "is_flax_available", - "is_keras_nlp_available", "is_matplotlib_available", "is_mlx_available", "is_phonemizer_available", @@ -238,7 +233,6 @@ "is_sentencepiece_available", "is_sklearn_available", "is_speech_available", - "is_tensorflow_text_available", "is_timm_available", "is_tokenizers_available", "is_torch_available", @@ -580,20 +574,6 @@ from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper from .generation import EtaLogitsWarper as EtaLogitsWarper from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty - from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor - from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor - from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor - from .generation import FlaxGenerationMixin as FlaxGenerationMixin - from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor - from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList - from .generation import FlaxLogitsWarper as FlaxLogitsWarper - from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor - from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor - from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor - from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper - from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper - from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper - from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor from .generation import GenerationConfig as GenerationConfig @@ -655,18 +635,14 @@ from .integrations import is_wandb_available as is_wandb_available from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache - from .keras_callbacks import KerasMetricCallback as KerasMetricCallback - from .keras_callbacks import PushToHubCallback as PushToHubCallback from .masking_utils import AttentionMaskInterface as AttentionMaskInterface from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context # Model Cards from .modelcard import ModelCard as ModelCard - from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update - from .modeling_utils import AttentionInterface as AttentionInterface from .modeling_utils import PreTrainedModel as PreTrainedModel from .models import * diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index d3bf1c6cfdcb..ff9bf6d7cd69 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -685,7 +685,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): mask_replace_prob: float = 0.8 random_replace_prob: float = 0.1 pad_to_multiple_of: Optional[int] = None - tf_experimental_compile: bool = False return_tensors: str = "pt" seed: Optional[int] = None @@ -1293,13 +1292,6 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch) return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} - def tf_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - if isinstance(examples[0], Mapping): - examples = [e["input_ids"] for e in examples] - batch = _tf_collate_batch(examples, self.tokenizer) - inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch) - return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} - def numpy_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: if isinstance(examples[0], Mapping): examples = [e["input_ids"] for e in examples] diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index e0be17bd7d28..b5c40ca44f1b 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -20,7 +20,7 @@ import numpy as np from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy +from .utils import PaddingStrategy, TensorType, is_torch_tensor, logging, to_numpy logger = logging.get_logger(__name__) @@ -116,7 +116,6 @@ def pad( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ @@ -145,7 +144,7 @@ def pad( processed_features["attention_mask"] = [] return processed_features - # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays + # If we have PyTorch tensors or lists as inputs, we cast them as Numpy arrays # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch @@ -159,9 +158,7 @@ def pad( first_element = required_input[index][0] if return_tensors is None: - if is_tf_tensor(first_element): - return_tensors = "tf" - elif is_torch_tensor(first_element): + if is_torch_tensor(first_element): return_tensors = "pt" elif isinstance(first_element, (int, float, list, tuple, np.ndarray)): return_tensors = "np" diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ae214c74f37d..65137cf0634f 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -47,9 +47,6 @@ TORCH_FX_REQUIRED_VERSION, TRANSFORMERS_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, - USE_JAX, - USE_TF, - USE_TORCH, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ContextManagers, @@ -79,7 +76,6 @@ is_datasets_available, is_detectron2_available, is_faiss_available, - is_flax_available, is_ftfy_available, is_g2p_en_available, is_in_notebook, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 673fdae99718..f27cef41afd6 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available +from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available _import_structure = { @@ -124,7 +124,6 @@ ] - if TYPE_CHECKING: from .configuration_utils import ( BaseWatermarkingConfig, diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 5ef1123b8fce..3d37b3396485 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -48,7 +48,6 @@ flatten_dict, is_datasets_available, is_pandas_available, - is_tf_available, is_torch_available, logging, ) @@ -56,8 +55,6 @@ logger = logging.get_logger(__name__) -if is_tf_available(): - from .. import TFPreTrainedModel if is_torch_available(): import torch @@ -760,12 +757,6 @@ def save_model_architecture_to_file(model: Any, output_dir: str): with open(f"{output_dir}/model_architecture.txt", "w+") as f: if isinstance(model, PreTrainedModel): print(model, file=f) - elif is_tf_available() and isinstance(model, TFPreTrainedModel): - - def print_to_file(s): - print(s, file=f) - - model.summary(print_fn=print_to_file) elif is_torch_available() and ( isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model") ): diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index d327d3b8489e..e68cab454929 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -56,8 +56,6 @@ from ...utils import ( TensorType, is_scipy_available, - is_tf_available, - is_tf_tensor, is_torch_available, is_torch_tensor, is_vision_available, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7f6843e98ca9..06da51b8f1de 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -49,7 +49,7 @@ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText -class Data2VecTextEmbeddings(nn.Module): +class Data2VecTextForTextEmbeddings(nn.Module): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. """ diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index ebf6c28c8765..9aa7860b7d6e 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -15,7 +15,6 @@ """PyTorch DecisionTransformer model.""" import math -import os from dataclasses import dataclass from typing import Callable, Optional, Union diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index efbc757e8630..6838b9c5cb75 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -37,7 +37,7 @@ TruncationStrategy, to_py_obj, ) -from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging +from ...utils import add_end_docstrings, is_torch_tensor, logging logger = logging.get_logger(__name__) @@ -1441,7 +1441,6 @@ def pad( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. verbose (`bool`, *optional*, defaults to `True`): @@ -1466,7 +1465,7 @@ def pad( encoded_inputs["attention_mask"] = [] return encoded_inputs - # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # If we have PyTorch/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch @@ -1480,9 +1479,7 @@ def pad( first_element = required_input[index][0] # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_tensor(first_element): - return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_tensor(first_element): + if is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 15f4db53287a..5c2e8c806da8 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -37,7 +37,7 @@ TruncationStrategy, to_py_obj, ) -from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging +from ...utils import add_end_docstrings, is_torch_tensor, logging from ...utils.import_utils import requires @@ -1279,7 +1279,6 @@ def pad( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. verbose (`bool`, *optional*, defaults to `True`): @@ -1304,7 +1303,7 @@ def pad( encoded_inputs["attention_mask"] = [] return encoded_inputs - # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # If we have PyTorch/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch @@ -1318,9 +1317,7 @@ def pad( first_element = required_input[index][0] # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_tensor(first_element): - return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_tensor(first_element): + if is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index ca3851218ccf..69124c13a115 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -652,7 +652,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->MT5, t5->mt5 class MT5PreTrainedModel(PreTrainedModel): config: MT5Config - load_tf_weights = load_tf_weights_in_mt5 base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 44fa05227ff8..d651f04a1bd1 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -15,9 +15,7 @@ # limitations under the License. """PyTorch OpenAI GPT model.""" -import json import math -import os from dataclasses import dataclass from typing import Any, Callable, Optional, Union @@ -41,84 +39,6 @@ logger = logging.get_logger(__name__) -def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): - """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" - import re - - import numpy as np - - if ".ckpt" in openai_checkpoint_folder_path: - openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) - - logger.info(f"Loading weights from {openai_checkpoint_folder_path}") - - with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: - names = json.load(names_handle) - with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: - shapes = json.load(shapes_handle) - offsets = np.cumsum([np.prod(shape) for shape in shapes]) - init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] - init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] - init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] - - # This was used when we had a single embedding matrix for positions and tokens - # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) - # del init_params[1] - init_params = [arr.squeeze() for arr in init_params] - - # Check that the token and position embeddings weight dimensions map those of the init parameters. - if model.tokens_embed.weight.shape != init_params[1].shape: - raise ValueError( - f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:" - f" {init_params[1].shape}" - ) - - if model.positions_embed.weight.shape != init_params[0].shape: - raise ValueError( - f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:" - f" {init_params[0].shape}" - ) - - model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) - model.positions_embed.weight.data = torch.from_numpy(init_params[0]) - names.pop(0) - # Pop position and token embedding arrays - init_params.pop(0) - init_params.pop(0) - - for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): - name = name[6:] # skip "model/" - if name[-2:] != ":0": - raise ValueError(f"Layer {name} does not end with :0") - name = name[:-2] - name = name.split("/") - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "w": - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - - # Ensure that the pointer and array have compatible shapes. - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu} @@ -359,7 +279,6 @@ def forward( @auto_docstring class OpenAIGPTPreTrainedModel(PreTrainedModel): config: OpenAIGPTConfig - load_tf_weights = load_tf_weights_in_openai_gpt base_model_prefix = "transformer" def _init_weights(self, module): @@ -849,5 +768,4 @@ def forward( "OpenAIGPTLMHeadModel", "OpenAIGPTModel", "OpenAIGPTPreTrainedModel", - "load_tf_weights_in_openai_gpt", ] diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index c9f59d1a0575..26eca95b456a 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -88,8 +88,7 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -687,7 +686,6 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: @auto_docstring class RoFormerPreTrainedModel(PreTrainedModel): config: RoFormerConfig - load_tf_weights = load_tf_weights_in_roformer base_model_prefix = "roformer" supports_gradient_checkpointing = True @@ -1425,5 +1423,4 @@ def forward( "RoFormerLayer", "RoFormerModel", "RoFormerPreTrainedModel", - "load_tf_weights_in_roformer", ] diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b7629c22933a..d5f09575e66d 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -652,7 +652,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring class T5PreTrainedModel(PreTrainedModel): config: T5Config - load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True @@ -2279,7 +2278,6 @@ def forward( "T5ForConditionalGeneration", "T5Model", "T5PreTrainedModel", - "load_tf_weights_in_t5", "T5ForQuestionAnswering", "T5ForSequenceClassification", "T5ForTokenClassification", diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 778fa7046f7d..8bac1b7cb235 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -25,7 +25,6 @@ from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import ( TensorType, - is_tf_available, is_torch_available, logging, ) @@ -35,9 +34,6 @@ if is_torch_available(): from ..modeling_utils import PreTrainedModel -if is_tf_available(): - from ..modeling_tf_utils import TFPreTrainedModel - if TYPE_CHECKING: from ..feature_extraction_utils import FeatureExtractionMixin from ..processing_utils import ProcessorMixin @@ -183,75 +179,9 @@ def export_pytorch( return matched_inputs, onnx_outputs -def export_tensorflow( - preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"], - model: "TFPreTrainedModel", - config: OnnxConfig, - opset: int, - output: Path, - tokenizer: Optional["PreTrainedTokenizer"] = None, -) -> tuple[list[str], list[str]]: - """ - Export a TensorFlow model to an ONNX Intermediate Representation (IR) - - Args: - preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]): - The preprocessor used for encoding the data. - model ([`TFPreTrainedModel`]): - The model to export. - config ([`~onnx.config.OnnxConfig`]): - The ONNX configuration associated with the exported model. - opset (`int`): - The version of the ONNX operator set to use. - output (`Path`): - Directory to store the exported ONNX model. - - Returns: - `tuple[list[str], list[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from - the ONNX configuration. - """ - import onnx - import tensorflow as tf - import tf2onnx - - if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None: - raise ValueError("You cannot provide both a tokenizer and preprocessor to export the model.") - if tokenizer is not None: - warnings.warn( - "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use" - " `preprocessor` instead.", - FutureWarning, - ) - logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.") - preprocessor = tokenizer - - model.config.return_dict = True - - # Check if we need to override certain configuration item - if config.values_override is not None: - logger.info(f"Overriding {len(config.values_override)} configuration item(s)") - for override_config_key, override_config_value in config.values_override.items(): - logger.info(f"\t- {override_config_key} -> {override_config_value}") - setattr(model.config, override_config_key, override_config_value) - - # Ensure inputs match - model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW) - inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) - onnx_outputs = list(config.outputs.keys()) - - input_signature = [ - tf.TensorSpec([None] * tensor.ndim, dtype=tensor.dtype, name=key) for key, tensor in model_inputs.items() - ] - onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset) - onnx.save(onnx_model, output.as_posix()) - config.restore_ops() - - return matched_inputs, onnx_outputs - - def export( preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"], - model: Union["PreTrainedModel", "TFPreTrainedModel"], + model: "PreTrainedModel", config: OnnxConfig, opset: int, output: Path, @@ -264,7 +194,7 @@ def export( Args: preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]): The preprocessor used for encoding the data. - model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + model ([`PreTrainedModel`): The model to export. config ([`~onnx.config.OnnxConfig`]): The ONNX configuration associated with the exported model. @@ -280,14 +210,8 @@ def export( `tuple[list[str], list[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from the ONNX configuration. """ - if not (is_torch_available() or is_tf_available()): - raise ImportError( - "Cannot convert because neither PyTorch nor TensorFlow are not installed. " - "Please install torch or tensorflow first." - ) - - if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda": - raise RuntimeError("`tf2onnx` does not support export on CUDA device.") + if not is_torch_available(): + raise ImportError("Cannot convert because PyTorchis not installed. Please install it first.") if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None: raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.") @@ -300,25 +224,22 @@ def export( logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.") preprocessor = tokenizer - if is_torch_available(): - from ..utils import get_torch_version + from ..utils import get_torch_version - if not config.is_torch_support_available: - logger.warning( - f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}," - f" got: {get_torch_version()}" - ) + if not config.is_torch_support_available: + logger.warning( + f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}," + f" got: {get_torch_version()}" + ) - if is_torch_available() and issubclass(type(model), PreTrainedModel): + if issubclass(type(model), PreTrainedModel): return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device) - elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): - return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer) def validate_model_outputs( config: OnnxConfig, preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"], - reference_model: Union["PreTrainedModel", "TFPreTrainedModel"], + reference_model: "PreTrainedModel", onnx_model: Path, onnx_named_outputs: list[str], atol: float, @@ -341,27 +262,20 @@ def validate_model_outputs( # generate inputs with a different batch_size and seq_len that was used for conversion to properly test # dynamic input shapes. - if is_torch_available() and issubclass(type(reference_model), PreTrainedModel): + if issubclass(type(reference_model), PreTrainedModel): reference_model_inputs = config.generate_dummy_inputs( preprocessor, batch_size=config.default_fixed_batch + 1, seq_length=config.default_fixed_sequence + 1, framework=TensorType.PYTORCH, ) - else: - reference_model_inputs = config.generate_dummy_inputs( - preprocessor, - batch_size=config.default_fixed_batch + 1, - seq_length=config.default_fixed_sequence + 1, - framework=TensorType.TENSORFLOW, - ) # Create ONNX Runtime session options = SessionOptions() session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"]) # Compute outputs from the reference model - if is_torch_available() and issubclass(type(reference_model), PreTrainedModel): + if issubclass(type(reference_model), PreTrainedModel): reference_model.to("cpu") ref_outputs = reference_model(**reference_model_inputs) ref_outputs_dict = {} @@ -439,16 +353,12 @@ def validate_model_outputs( def ensure_model_and_config_inputs_match( - model: Union["PreTrainedModel", "TFPreTrainedModel"], model_inputs: Iterable[str] + model: "PreTrainedModel", model_inputs: Iterable[str] ) -> tuple[bool, list[str]]: """ - :param model_inputs: :param config_inputs: :return: """ - if is_torch_available() and issubclass(type(model), PreTrainedModel): - forward_parameters = signature(model.forward).parameters - else: - forward_parameters = signature(model.call).parameters + forward_parameters = signature(model.forward).parameters model_inputs_set = set(model_inputs) # We are fine if config_inputs has more keys than model_inputs diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index a24aa79a5968..5a2180798c22 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -1,16 +1,15 @@ -import os from functools import partial, reduce -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional import transformers -from .. import PretrainedConfig, is_tf_available, is_torch_available -from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging +from .. import PretrainedConfig, is_torch_available +from ..utils import logging from .config import OnnxConfig if TYPE_CHECKING: - from transformers import PreTrainedModel, TFPreTrainedModel + from transformers import PreTrainedModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -33,22 +32,9 @@ AutoModelForTokenClassification, AutoModelForVision2Seq, ) -if is_tf_available(): - from transformers.models.auto import ( - TFAutoModel, - TFAutoModelForCausalLM, - TFAutoModelForMaskedLM, - TFAutoModelForMultipleChoice, - TFAutoModelForQuestionAnswering, - TFAutoModelForSemanticSegmentation, - TFAutoModelForSeq2SeqLM, - TFAutoModelForSequenceClassification, - TFAutoModelForTokenClassification, - ) -if not is_torch_available() and not is_tf_available(): +else: logger.warning( - "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models" - " without one of these libraries installed." + "The ONNX export features is only supported for PyTorch. You will not be able to export models without it installed." ) @@ -84,7 +70,6 @@ def supported_features_mapping( class FeaturesManager: _TASKS_TO_AUTOMODELS = {} - _TASKS_TO_TF_AUTOMODELS = {} if is_torch_available(): _TASKS_TO_AUTOMODELS = { "default": AutoModel, @@ -103,18 +88,6 @@ class FeaturesManager: "vision2seq-lm": AutoModelForVision2Seq, "speech2seq-lm": AutoModelForSpeechSeq2Seq, } - if is_tf_available(): - _TASKS_TO_TF_AUTOMODELS = { - "default": TFAutoModel, - "masked-lm": TFAutoModelForMaskedLM, - "causal-lm": TFAutoModelForCausalLM, - "seq2seq-lm": TFAutoModelForSeq2SeqLM, - "sequence-classification": TFAutoModelForSequenceClassification, - "token-classification": TFAutoModelForTokenClassification, - "multiple-choice": TFAutoModelForMultipleChoice, - "question-answering": TFAutoModelForQuestionAnswering, - "semantic-segmentation": TFAutoModelForSemanticSegmentation, - } # Set of model topologies we support associated to the features supported by each topology and the factory _SUPPORTED_MODEL_TYPE = { @@ -584,40 +557,19 @@ def feature_to_task(feature: str) -> str: return feature.replace("-with-past", "") @staticmethod - def _validate_framework_choice(framework: str): - """ - Validates if the framework requested for the export is both correct and available, otherwise throws an - exception. - """ - if framework not in ["pt", "tf"]: - raise ValueError( - f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided." - ) - elif framework == "pt" and not is_torch_available(): - raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.") - elif framework == "tf" and not is_tf_available(): - raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.") - - @staticmethod - def get_model_class_for_feature(feature: str, framework: str = "pt") -> type: + def get_model_class_for_feature(feature: str) -> type: """ Attempts to retrieve an AutoModel class from a feature name. Args: feature (`str`): The feature required. - framework (`str`, *optional*, defaults to `"pt"`): - The framework to use for the export. Returns: The AutoModel class corresponding to the feature. """ task = FeaturesManager.feature_to_task(feature) - FeaturesManager._validate_framework_choice(framework) - if framework == "pt": - task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS - else: - task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS + task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS if task not in task_to_automodel: raise KeyError( f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" @@ -626,59 +578,7 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> type: return task_to_automodel[task] @staticmethod - def determine_framework(model: str, framework: Optional[str] = None) -> str: - """ - Determines the framework to use for the export. - - The priority is in the following order: - 1. User input via `framework`. - 2. If local checkpoint is provided, use the same framework as the checkpoint. - 3. Available framework in environment, with priority given to PyTorch - - Args: - model (`str`): - The name of the model to export. - framework (`str`, *optional*, defaults to `None`): - The framework to use for the export. See above for priority if none provided. - - Returns: - The framework to use for the export. - - """ - if framework is not None: - return framework - - framework_map = {"pt": "PyTorch", "tf": "TensorFlow"} - exporter_map = {"pt": "torch", "tf": "tf2onnx"} - - if os.path.isdir(model): - if os.path.isfile(os.path.join(model, WEIGHTS_NAME)): - framework = "pt" - elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)): - framework = "tf" - else: - raise FileNotFoundError( - "Cannot determine framework from given checkpoint location." - f" There should be a {WEIGHTS_NAME} for PyTorch" - f" or {TF2_WEIGHTS_NAME} for TensorFlow." - ) - logger.info(f"Local {framework_map[framework]} model found.") - else: - if is_torch_available(): - framework = "pt" - elif is_tf_available(): - framework = "tf" - else: - raise OSError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.") - - logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.") - - return framework - - @staticmethod - def get_model_from_feature( - feature: str, model: str, framework: Optional[str] = None, cache_dir: Optional[str] = None - ) -> Union["PreTrainedModel", "TFPreTrainedModel"]: + def get_model_from_feature(feature: str, model: str, cache_dir: Optional[str] = None) -> PreTrainedModel: """ Attempts to retrieve a model from a model's name and the feature to be enabled. @@ -687,31 +587,17 @@ def get_model_from_feature( The feature required. model (`str`): The name of the model to export. - framework (`str`, *optional*, defaults to `None`): - The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should - none be provided. Returns: The instance of the model. """ - framework = FeaturesManager.determine_framework(model, framework) - model_class = FeaturesManager.get_model_class_for_feature(feature, framework) - try: - model = model_class.from_pretrained(model, cache_dir=cache_dir) - except OSError: - if framework == "pt": - logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.") - model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir) - else: - logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.") - model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir) + model_class = FeaturesManager.get_model_class_for_feature(feature) + model = model_class.from_pretrained(model, cache_dir=cache_dir) return model @staticmethod - def check_supported_model_or_raise( - model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default" - ) -> tuple[str, Callable]: + def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> tuple[str, Callable]: """ Check whether or not the model has the requested features. diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 7cccf60ca3da..30af924611ae 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -42,7 +42,6 @@ is_offline_mode, is_peft_available, is_pyctcdecode_available, - is_tf_available, is_torch_available, logging, ) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index a30c45285982..6f87fb3aebdf 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -18,7 +18,6 @@ from ..utils import ( ExplicitEnum, add_end_docstrings, - is_tf_available, is_torch_available, is_vision_available, logging, @@ -32,9 +31,6 @@ from ..image_utils import load_image -if is_tf_available(): - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES - if is_torch_available(): import torch @@ -107,11 +103,7 @@ class ImageClassificationPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) requires_backends(self, "vision") - self.check_model_type( - TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES) def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None): preprocess_params = {} @@ -190,9 +182,8 @@ def __call__( def preprocess(self, image, timeout=None): image = load_image(image, timeout=timeout) - model_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(images=image, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) return model_inputs def _forward(self, model_inputs): @@ -214,7 +205,7 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5): top_k = self.model.config.num_labels outputs = model_outputs["logits"][0] - if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16): + if outputs.dtype in (torch.bfloat16, torch.float16): outputs = outputs.to(torch.float32).numpy() else: outputs = outputs.numpy() diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 51f9e70cdd61..f77a19603072 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -18,7 +18,6 @@ from ..generation import GenerationConfig from ..utils import ( add_end_docstrings, - is_tf_available, is_torch_available, is_vision_available, logging, @@ -32,9 +31,6 @@ from ..image_utils import load_image -if is_tf_available(): - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES - if is_torch_available(): import torch @@ -84,9 +80,7 @@ class ImageToTextPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) requires_backends(self, "vision") - self.check_model_type( - TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None): forward_params = {} @@ -174,34 +168,30 @@ def preprocess(self, image, prompt=None, timeout=None): model_type = self.model.config.model_type if model_type == "git": - model_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(images=image, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids input_ids = [self.tokenizer.cls_token_id] + input_ids input_ids = torch.tensor(input_ids).unsqueeze(0) model_inputs.update({"input_ids": input_ids}) elif model_type == "pix2struct": - model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) elif model_type != "vision-encoder-decoder": # vision-encoder-decoder does not support conditional generation - model_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) - text_inputs = self.tokenizer(prompt, return_tensors=self.framework) + model_inputs = self.image_processor(images=image, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) + text_inputs = self.tokenizer(prompt, return_tensors="pt") model_inputs.update(text_inputs) else: raise ValueError(f"Model type {model_type} does not support conditional text generation") else: - model_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(images=image, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) if self.model.config.model_type == "git" and prompt is None: model_inputs["input_ids"] = None @@ -222,10 +212,6 @@ def _forward(self, model_inputs, **generate_kwargs): if "generation_config" not in generate_kwargs: generate_kwargs["generation_config"] = self.generation_config - # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py` - # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas - # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name` - # in the `_prepare_model_inputs` method. inputs = model_inputs.pop(self.model.main_input_name) model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs) return model_outputs diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index 6f11f3bc9741..949f1ff498be 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -4,13 +4,10 @@ import numpy as np -from ..utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available +from ..utils import ExplicitEnum, add_end_docstrings, is_torch_available from .base import GenericTensor, Pipeline, build_pipeline_init_args -if is_tf_available(): - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES - if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES @@ -89,11 +86,7 @@ class TextClassificationPipeline(Pipeline): def __init__(self, **kwargs): super().__init__(**kwargs) - self.check_model_type( - TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES) def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs): # Using "" as default argument because we're going to use `top_k=None` in user code to declare @@ -175,7 +168,7 @@ def __call__( return result def preprocess(self, inputs, **tokenizer_kwargs) -> dict[str, GenericTensor]: - return_tensors = self.framework + return_tensors = "pt" if isinstance(inputs, dict): return self.tokenizer(**inputs, return_tensors=return_tensors, **tokenizer_kwargs) elif isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], list) and len(inputs[0]) == 2: @@ -193,7 +186,7 @@ def preprocess(self, inputs, **tokenizer_kwargs) -> dict[str, GenericTensor]: def _forward(self, model_inputs): # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported - model_forward = self.model.forward if self.framework == "pt" else self.model.call + model_forward = self.model.forward if "use_cache" in inspect.signature(model_forward).parameters: model_inputs["use_cache"] = False return self.model(**model_inputs) @@ -217,11 +210,8 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=Tr outputs = model_outputs["logits"][0] - if self.framework == "pt": - # To enable using fp16 and bf16 - outputs = outputs.float().numpy() - else: - outputs = outputs.numpy() + # To enable using fp16 and bf16 + outputs = outputs.float().numpy() if function_to_apply == ClassificationFunction.SIGMOID: scores = sigmoid(outputs) diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index 6aeb91620306..e43e85879f9d 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -4,7 +4,6 @@ from ..utils import ( add_end_docstrings, - is_tf_available, is_torch_available, is_vision_available, logging, @@ -23,9 +22,6 @@ from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES -if is_tf_available(): - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES - from ..tf_utils import stable_softmax logger = logging.get_logger(__name__) @@ -73,11 +69,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) requires_backends(self, "vision") - self.check_model_type( - TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES - if self.framework == "tf" - else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES - ) + self.check_model_type(MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES) @overload def __call__( @@ -160,16 +152,15 @@ def preprocess( if tokenizer_kwargs is None: tokenizer_kwargs = {} image = load_image(image, timeout=timeout) - inputs = self.image_processor(images=[image], return_tensors=self.framework) - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = self.image_processor(images=[image], return_tensors="pt") + inputs = inputs.to(self.dtype) inputs["candidate_labels"] = candidate_labels sequences = [hypothesis_template.format(x) for x in candidate_labels] tokenizer_default_kwargs = {"padding": True} if "siglip" in self.model.config.model_type: tokenizer_default_kwargs.update(padding="max_length", max_length=64, truncation=True) tokenizer_default_kwargs.update(tokenizer_kwargs) - text_inputs = self.tokenizer(sequences, return_tensors=self.framework, **tokenizer_default_kwargs) + text_inputs = self.tokenizer(sequences, return_tensors="pt", **tokenizer_default_kwargs) inputs["text_inputs"] = [text_inputs] return inputs @@ -193,21 +184,16 @@ def _forward(self, model_inputs): def postprocess(self, model_outputs): candidate_labels = model_outputs.pop("candidate_labels") logits = model_outputs["logits"][0] - if self.framework == "pt" and "siglip" in self.model.config.model_type: + if "siglip" in self.model.config.model_type: probs = torch.sigmoid(logits).squeeze(-1) scores = probs.tolist() if not isinstance(scores, list): scores = [scores] - elif self.framework == "pt": + else: probs = logits.softmax(dim=-1).squeeze(-1) scores = probs.tolist() if not isinstance(scores, list): scores = [scores] - elif self.framework == "tf": - probs = stable_softmax(logits, axis=-1) - scores = probs.numpy().tolist() - else: - raise ValueError(f"Unsupported framework: {self.framework}") result = [ {"score": score, "label": candidate_label} diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index b87223ad99d0..ee1bb3f4f4ca 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -113,9 +113,6 @@ GGUF_MIN_VERSION, TORCH_FX_REQUIRED_VERSION, TRITON_MIN_VERSION, - USE_JAX, - USE_TF, - USE_TORCH, XLA_FSDPV2_MIN_VERSION, DummyObject, OptionalDependencyNotAvailable, @@ -152,7 +149,6 @@ is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, - is_flax_available, is_flute_available, is_fp_quant_available, is_fsdp_available, @@ -171,7 +167,6 @@ is_jinja_available, is_jumanpp_available, is_kenlm_available, - is_keras_nlp_available, is_kernels_available, is_levenshtein_available, is_libcst_available, @@ -221,10 +216,6 @@ is_spqr_available, is_sudachi_available, is_sudachi_projection_available, - is_tensorflow_probability_available, - is_tensorflow_text_available, - is_tf2onnx_available, - is_tf_available, is_tiktoken_available, is_timm_available, is_tokenizers_available, diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 573dfad08f5b..f873175a5d49 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -64,9 +64,7 @@ from .generic import working_or_temp_dir from .import_utils import ( ENV_VARS_TRUE_VALUES, - _tf_version, _torch_version, - is_tf_available, is_torch_available, is_training_run_on_sagemaker, ) @@ -232,8 +230,6 @@ def http_user_agent(user_agent: Union[dict, str, None] = None) -> str: ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" if is_torch_available(): ua += f"; torch/{_torch_version}" - if is_tf_available(): - ua += f"; tensorflow/{_tf_version}" if constants.HF_HUB_DISABLE_TELEMETRY: return ua + "; telemetry/off" if is_training_run_on_sagemaker(): @@ -1020,7 +1016,7 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"): data["dataset_name"] = args_as_dict["dataset_name"] elif "task_name" in args_as_dict: # Extract script name from the example_name - script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "") + script_name = example_name.replace("run_", "") script_name = script_name.replace("_no_trainer", "") data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}" diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 2f6dc0b8e714..df06bd05842d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -99,15 +99,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) -USE_TF = os.environ.get("USE_TF", "AUTO").upper() -USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() -USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() - # Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0. USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() -FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() - # `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it. # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. TORCH_FX_REQUIRED_VERSION = version.parse("1.10") @@ -221,9 +215,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _soundfile_available = _is_package_available("soundfile") _spacy_available = _is_package_available("spacy") _sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True) -_tensorflow_probability_available = _is_package_available("tensorflow_probability") -_tensorflow_text_available = _is_package_available("tensorflow_text") -_tf2onnx_available = _is_package_available("tf2onnx") _timm_available = _is_package_available("timm") _tokenizers_available = _is_package_available("tokenizers") _torchaudio_available = _is_package_available("torchaudio") @@ -243,60 +234,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _mistral_common_available = _is_package_available("mistral_common") _triton_available, _triton_version = _is_package_available("triton", return_version=True) -_torch_version = "N/A" -_torch_available = False -if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available, _torch_version = _is_package_available("torch", return_version=True) - if _torch_available: - _torch_available = version.parse(_torch_version) >= version.parse("2.1.0") - if not _torch_available: - logger.warning(f"Disabling PyTorch because PyTorch >= 2.1 is required but found {_torch_version}") -else: - logger.info("Disabling PyTorch because USE_TF is set") - _torch_available = False - - -_tf_version = "N/A" -_tf_available = False -if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: - _tf_available = True -else: - if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below - # with tensorflow-cpu to make sure it still works! - _tf_available = importlib.util.find_spec("tensorflow") is not None - if _tf_available: - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "tf-nightly-rocm", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - "tensorflow-aarch64", - ) - _tf_version = None - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for pkg in candidates: - try: - _tf_version = importlib.metadata.version(pkg) - break - except importlib.metadata.PackageNotFoundError: - pass - _tf_available = _tf_version is not None - if _tf_available: - if version.parse(_tf_version) < version.parse("2"): - logger.info( - f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum." - ) - _tf_available = False - else: - logger.info("Disabling Tensorflow because USE_TORCH is set") +_torch_available, _torch_version = _is_package_available("torch", return_version=True) +if _torch_available: + _torch_available = version.parse(_torch_version) >= version.parse("2.1.0") + if not _torch_available: + logger.warning(f"Disabling PyTorch because PyTorch >= 2.1 is required but found {_torch_version}") _essentia_available = importlib.util.find_spec("essentia") is not None @@ -351,18 +293,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _is_ccl_available = False -_flax_available = False -if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - _flax_available, _flax_version = _is_package_available("flax", return_version=True) - if _flax_available: - _jax_available, _jax_version = _is_package_available("jax", return_version=True) - if _jax_available: - logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") - else: - _flax_available = _jax_available = False - _jax_version = _flax_version = "N/A" - - _torch_xla_available = False if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES: _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True) @@ -761,26 +691,14 @@ def is_bs4_available() -> Union[tuple[bool, str], bool]: return _bs4_available -def is_tf_available() -> bool: - return _tf_available - - def is_coloredlogs_available() -> Union[tuple[bool, str], bool]: return _coloredlogs_available -def is_tf2onnx_available() -> Union[tuple[bool, str], bool]: - return _tf2onnx_available - - def is_onnx_available() -> Union[tuple[bool, str], bool]: return _onnx_available -def is_flax_available() -> bool: - return _flax_available - - def is_flute_available() -> bool: try: return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1" @@ -1446,14 +1364,6 @@ def is_spacy_available() -> Union[tuple[bool, str], bool]: return _spacy_available -def is_tensorflow_text_available() -> Union[tuple[bool, str], bool]: - return is_tf_available() and _tensorflow_text_available - - -def is_keras_nlp_available() -> Union[tuple[bool, str], bool]: - return is_tensorflow_text_available() and _keras_nlp_available - - def is_in_notebook() -> bool: try: # Check if we are running inside Marimo @@ -1478,10 +1388,6 @@ def is_pytorch_quantization_available() -> Union[tuple[bool, str], bool]: return _pytorch_quantization_available -def is_tensorflow_probability_available() -> Union[tuple[bool, str], bool]: - return _tensorflow_probability_available - - def is_pandas_available() -> Union[tuple[bool, str], bool]: return _pandas_available @@ -1572,10 +1478,7 @@ def is_uroman_available() -> Union[tuple[bool, str], bool]: def torch_only_method(fn: Callable) -> Callable: def wrapper(*args, **kwargs): if not _torch_available: - raise ImportError( - "You need to install pytorch to use this method or class, " - "or activate it with environment variables USE_TORCH=1 and USE_TF=0." - ) + raise ImportError("You need to install pytorch to use this method or class") else: return fn(*args, **kwargs) @@ -1771,30 +1674,6 @@ def check_torch_load_is_safe() -> None: Please note that you may need to restart your runtime after installation. """ -# docstyle-ignore -PYTORCH_IMPORT_ERROR_WITH_TF = """ -{0} requires the PyTorch library but it was not found in your environment. -However, we were able to find a TensorFlow installation. TensorFlow classes begin -with "TF", but are otherwise identically named to our PyTorch classes. This -means that the TF equivalent of the class you tried to import would be "TF{0}". -If you want to use TensorFlow, please use TF classes instead! - -If you really do want to use PyTorch please go to -https://pytorch.org/get-started/locally/ and follow the instructions that -match your environment. -""" - -# docstyle-ignore -TF_IMPORT_ERROR_WITH_PYTORCH = """ -{0} requires the TensorFlow library but it was not found in your environment. -However, we were able to find a PyTorch installation. PyTorch classes do not begin -with "TF", but are otherwise identically named to our TF classes. -If you want to use PyTorch, please use those classes instead! - -If you really do want to use TensorFlow, please follow the instructions on the -installation page https://www.tensorflow.org/install that match your environment. -""" - # docstyle-ignore BS4_IMPORT_ERROR = """ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: @@ -1816,14 +1695,6 @@ def check_torch_load_is_safe() -> None: """ -# docstyle-ignore -TENSORFLOW_IMPORT_ERROR = """ -{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the -installation page: https://www.tensorflow.org/install and follow the ones that match your environment. -Please note that you may need to restart your runtime after installation. -""" - - # docstyle-ignore DETECTRON2_IMPORT_ERROR = """ {0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the @@ -1832,13 +1703,6 @@ def check_torch_load_is_safe() -> None: """ -# docstyle-ignore -FLAX_IMPORT_ERROR = """ -{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the -installation page: https://github.com/google/flax and follow the ones that match your environment. -Please note that you may need to restart your runtime after installation. -""" - # docstyle-ignore FTFY_IMPORT_ERROR = """ {0} requires the ftfy library but it was not found in your environment. Check out the instructions on the @@ -1864,19 +1728,6 @@ def check_torch_load_is_safe() -> None: Please note that you may need to restart your runtime after installation. """ -# docstyle-ignore -TENSORFLOW_PROBABILITY_IMPORT_ERROR = """ -{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as -explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation. -""" - -# docstyle-ignore -TENSORFLOW_TEXT_IMPORT_ERROR = """ -{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as -explained here: https://www.tensorflow.org/text/guide/tf_text_intro. -Please note that you may need to restart your runtime after installation. -""" - # docstyle-ignore TORCHAUDIO_IMPORT_ERROR = """ {0} requires the torchaudio library but it was not found in your environment. Please install it and restart your @@ -2071,7 +1922,6 @@ def check_torch_load_is_safe() -> None: ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)), ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)), ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)), ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), @@ -2088,9 +1938,6 @@ def check_torch_load_is_safe() -> None: ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)), - ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)), - ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), - ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)), ("timm", (is_timm_available, TIMM_IMPORT_ERROR)), ("torchaudio", (is_torchaudio_available, TORCHAUDIO_IMPORT_ERROR)), ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)), @@ -2109,7 +1956,6 @@ def check_torch_load_is_safe() -> None: ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)), ("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)), ("rich", (is_rich_available, RICH_IMPORT_ERROR)), - ("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)), ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)), ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)), ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)), @@ -2125,14 +1971,6 @@ def requires_backends(obj, backends): name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - # Raise an error for users who might not realize that classes without "TF" are torch-only - if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): - raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) - - # Raise the inverse error for PyTorch users trying to load TF classes - if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): - raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) - failed = [] for backend in backends: if isinstance(backend, Backend): @@ -2464,8 +2302,6 @@ def inner_fn(fun): BASE_FILE_REQUIREMENTS = { - lambda e: "modeling_tf_" in e: ("tf",), - lambda e: "modeling_flax_" in e: ("flax",), lambda e: "modeling_" in e: ("torch",), lambda e: e.startswith("tokenization_") and e.endswith("_fast"): ("tokenizers",), lambda e: e.startswith("image_processing_") and e.endswith("_fast"): ("vision", "torch", "torchvision"), @@ -2536,8 +2372,6 @@ def create_import_structure_from_path(module_path): backend specified. The default backends are defined according to the filename: - If a file is named like `modeling_*.py`, it will have a `torch` backend - - If a file is named like `modeling_tf_*.py`, it will have a `tf` backend - - If a file is named like `modeling_flax_*.py`, it will have a `flax` backend - If a file is named like `tokenization_*_fast.py`, it will have a `tokenizers` backend - If a file is named like `image_processing*_fast.py`, it will have a `torchvision` + `torch` backend @@ -2615,8 +2449,8 @@ def find_substring(substring, list_): previous_index = 0 # Some files have some requirements by default. - # For example, any file named `modeling_tf_xxx.py` - # should have TensorFlow as a required backend. + # For example, any file named `modeling_xxx.py` + # should have torch as a required backend. base_requirements = () for string_check, requirements in BASE_FILE_REQUIREMENTS.items(): if string_check(module_name): @@ -2652,7 +2486,6 @@ def find_substring(substring, list_): # backends=( # "sentencepiece", # "torch", - # "tf", # ) # ) # @@ -2660,7 +2493,7 @@ def find_substring(substring, list_): # # @export( # backends=( - # "sentencepiece", "tf" + # "sentencepiece", # ) # ) elif "backends" in lines[previous_index + 1]: diff --git a/tests/models/byt5/test_tokenization_byt5.py b/tests/models/byt5/test_tokenization_byt5.py index 35dd95424dcc..baadfc67c2b8 100644 --- a/tests/models/byt5/test_tokenization_byt5.py +++ b/tests/models/byt5/test_tokenization_byt5.py @@ -21,19 +21,10 @@ from functools import cached_property from transformers import AddedToken, BatchEncoding, ByT5Tokenizer -from transformers.utils import is_tf_available, is_torch_available from ...test_tokenization_common import TokenizerTesterMixin -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - - class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = ByT5Tokenizer test_rust_tokenizer = False @@ -122,13 +113,10 @@ def test_prepare_batch_integration(self): tokenizer = self.t5_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] expected_src_tokens = [68, 35, 111, 114, 113, 106, 35, 115, 100, 117, 100, 106, 117, 100, 115, 107, 35, 105, 114, 117, 35, 118, 120, 112, 112, 100, 117, 108, 125, 100, 119, 108, 114, 113, 49, 1, 0] # fmt: skip - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") self.assertIsInstance(batch, BatchEncoding) - if FRAMEWORK != "jax": - result = list(batch.input_ids.numpy()[0]) - else: - result = list(batch.input_ids.tolist()[0]) + result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) @@ -138,7 +126,7 @@ def test_prepare_batch_integration(self): def test_empty_target_text(self): tokenizer = self.t5_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") # check if input_ids are returned and no decoder_input_ids self.assertIn("input_ids", batch) self.assertIn("attention_mask", batch) @@ -152,7 +140,7 @@ def test_max_length_integration(self): "Another summary.", ] targets = tokenizer( - text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt" ) self.assertEqual(32, targets["input_ids"].shape[1]) diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index a96f3b5b6aad..5a2a008db02f 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -25,9 +25,7 @@ AddedToken, LayoutLMv2TokenizerFast, SpecialTokensMixin, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -1720,12 +1718,7 @@ def test_batch_encode_dynamic_overflowing(self): tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): - if is_torch_available(): - returned_tensor = "pt" - elif is_tf_available(): - returned_tensor = "tf" - else: - returned_tensor = "jax" + returned_tensor = "pt" # Single example words, boxes = self.get_words_and_boxes() @@ -2434,18 +2427,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -2453,7 +2434,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index c487e662bf9a..1568a7e01104 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -26,9 +26,7 @@ AddedToken, LayoutLMv3TokenizerFast, SpecialTokensMixin, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -1605,12 +1603,7 @@ def test_batch_encode_dynamic_overflowing(self): tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): - if is_torch_available(): - returned_tensor = "pt" - elif is_tf_available(): - returned_tensor = "tf" - else: - returned_tensor = "jax" + returned_tensor = "pt" # Single example words = ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"] @@ -2358,18 +2351,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -2377,7 +2358,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT, or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index 21bd0a469dd4..5c09a14ecaa6 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -23,9 +23,7 @@ AddedToken, LayoutXLMTokenizerFast, SpecialTokensMixin, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -1649,12 +1647,7 @@ def test_batch_encode_dynamic_overflowing(self): tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): - if is_torch_available(): - returned_tensor = "pt" - elif is_tf_available(): - returned_tensor = "tf" - else: - returned_tensor = "jax" + returned_tensor = "pt" # Single example words, boxes = self.get_words_and_boxes() @@ -1928,18 +1921,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -1947,7 +1928,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/models/marian/test_tokenization_marian.py b/tests/models/marian/test_tokenization_marian.py index ebe26c5babb7..fbcd94c29f85 100644 --- a/tests/models/marian/test_tokenization_marian.py +++ b/tests/models/marian/test_tokenization_marian.py @@ -19,7 +19,7 @@ from transformers import BatchEncoding, MarianTokenizer from transformers.testing_utils import get_tests_dir, require_sentencepiece, slow -from transformers.utils import is_sentencepiece_available, is_tf_available, is_torch_available +from transformers.utils import is_sentencepiece_available if is_sentencepiece_available(): @@ -34,13 +34,6 @@ zh_code = ">>zh<<" ORG_NAME = "Helsinki-NLP/" -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - @require_sentencepiece class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @@ -112,14 +105,14 @@ def test_outputs_not_longer_than_maxlen(self): tok = self.get_tokenizer() batch = tok( - ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK + ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors="pt" ) self.assertIsInstance(batch, BatchEncoding) self.assertEqual(batch.input_ids.shape, (2, 512)) def test_outputs_can_be_shorter(self): tok = self.get_tokenizer() - batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK) + batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors="pt") self.assertIsInstance(batch_smaller, BatchEncoding) self.assertEqual(batch_smaller.input_ids.shape, (2, 10)) diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index ee08dbf51b01..0d5ef0efdb02 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -26,9 +26,7 @@ AddedToken, MarkupLMTokenizerFast, SpecialTokensMixin, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -1505,12 +1503,7 @@ def test_batch_encode_dynamic_overflowing(self): tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): - if is_torch_available(): - returned_tensor = "pt" - elif is_tf_available(): - returned_tensor = "tf" - else: - returned_tensor = "jax" + returned_tensor = "pt" # Single example nodes, xpaths = self.get_nodes_and_xpaths() @@ -2276,18 +2269,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -2295,7 +2276,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT, or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/models/myt5/test_tokenization_myt5.py b/tests/models/myt5/test_tokenization_myt5.py index f36b6c7ec56c..5f57aa051b71 100644 --- a/tests/models/myt5/test_tokenization_myt5.py +++ b/tests/models/myt5/test_tokenization_myt5.py @@ -16,19 +16,10 @@ from transformers import MyT5Tokenizer from transformers.testing_utils import slow -from transformers.utils import is_tf_available, is_torch_available from ...test_tokenization_common import TokenizerTesterMixin -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - - def bytes_to_hex(bline: bytes, sep: str = " ") -> str: return str(binascii.hexlify(bline, sep), "utf-8") diff --git a/tests/models/pegasus/test_tokenization_pegasus.py b/tests/models/pegasus/test_tokenization_pegasus.py index bd8d07cdabae..ce65a1065c9f 100644 --- a/tests/models/pegasus/test_tokenization_pegasus.py +++ b/tests/models/pegasus/test_tokenization_pegasus.py @@ -187,25 +187,6 @@ def test_large_seq2seq_truncation(self): assert len(batch) == 2 # input_ids, attention_mask. def test_equivalence_to_orig_tokenizer(self): - """ - To run with original TF tokenizer: - - !wget https://github.com/google-research/bigbird/raw/master/bigbird/vocab/pegasus.model - !pip install tensorflow-text - - import tensorflow.compat.v2 as tf - import tensorflow_text as tft - - VOCAB_FILE = "./pegasus.model" - - tf.enable_v2_behavior() - - test_str = "This is an example string that is used to test the original TF implementation against the HF implementation" - tokenizer = tft.SentencepieceTokenizer(model=tf.io.gfile.GFile(VOCAB_FILE, "rb").read()) - - tokenizer.tokenize(test_str) - """ - test_str = ( "This is an example string that is used to test the original TF implementation against the HF" " implementation" diff --git a/tests/models/perceiver/test_tokenization_perceiver.py b/tests/models/perceiver/test_tokenization_perceiver.py index 87d9f3d0075c..7dc0ae6e9867 100644 --- a/tests/models/perceiver/test_tokenization_perceiver.py +++ b/tests/models/perceiver/test_tokenization_perceiver.py @@ -21,19 +21,10 @@ from functools import cached_property from transformers import AddedToken, BatchEncoding, PerceiverTokenizer -from transformers.utils import is_tf_available, is_torch_available from ...test_tokenization_common import TokenizerTesterMixin -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - - class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase): from_pretrained_id = "deepmind/language-perceiver" tokenizer_class = PerceiverTokenizer @@ -117,13 +108,10 @@ def test_prepare_batch_integration(self): tokenizer = self.perceiver_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] expected_src_tokens = [4, 71, 38, 114, 117, 116, 109, 38, 118, 103, 120, 103, 109, 120, 103, 118, 110, 38, 108, 117, 120, 38, 121, 123, 115, 115, 103, 120, 111, 128, 103, 122, 111, 117, 116, 52, 5, 0] # fmt: skip - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") self.assertIsInstance(batch, BatchEncoding) - if FRAMEWORK != "jax": - result = list(batch.input_ids.numpy()[0]) - else: - result = list(batch.input_ids.tolist()[0]) + result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) @@ -133,7 +121,7 @@ def test_prepare_batch_integration(self): def test_empty_target_text(self): tokenizer = self.perceiver_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") # check if input_ids are returned and no decoder_input_ids self.assertIn("input_ids", batch) self.assertIn("attention_mask", batch) @@ -147,7 +135,7 @@ def test_max_length_integration(self): "Another summary.", ] targets = tokenizer( - text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt" ) self.assertEqual(32, targets["input_ids"].shape[1]) diff --git a/tests/models/sam2/test_processor_sam2.py b/tests/models/sam2/test_processor_sam2.py index d0b099c77698..1c388e210836 100644 --- a/tests/models/sam2/test_processor_sam2.py +++ b/tests/models/sam2/test_processor_sam2.py @@ -22,7 +22,7 @@ require_torchvision, require_vision, ) -from transformers.utils import is_tf_available, is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_vision_available if is_vision_available(): @@ -31,9 +31,6 @@ if is_torch_available(): import torch -if is_tf_available(): - pass - @require_vision @require_torchvision diff --git a/tests/models/sam2_video/test_processor_sam2_video.py b/tests/models/sam2_video/test_processor_sam2_video.py index 0e359e716b9d..6e071158be11 100644 --- a/tests/models/sam2_video/test_processor_sam2_video.py +++ b/tests/models/sam2_video/test_processor_sam2_video.py @@ -22,7 +22,7 @@ require_torchvision, require_vision, ) -from transformers.utils import is_tf_available, is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_vision_available if is_vision_available(): @@ -31,9 +31,6 @@ if is_torch_available(): import torch -if is_tf_available(): - pass - @require_vision @require_torchvision diff --git a/tests/models/siglip/test_tokenization_siglip.py b/tests/models/siglip/test_tokenization_siglip.py index 68e7f1fbf4b1..6d54d7dad26a 100644 --- a/tests/models/siglip/test_tokenization_siglip.py +++ b/tests/models/siglip/test_tokenization_siglip.py @@ -20,20 +20,17 @@ from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow +<<<<<<< HEAD from transformers.utils import is_tf_available, is_torch_available +======= +from transformers.utils import cached_property +>>>>>>> 928d4415fa (more and more) from ...test_tokenization_common import TokenizerTesterMixin SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - @require_sentencepiece @require_tokenizers @@ -173,13 +170,10 @@ def test_prepare_batch(self): tokenizer = self.siglip_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, tokenizer.eos_token_id] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") self.assertIsInstance(batch, BatchEncoding) - if FRAMEWORK != "jax": - result = list(batch.input_ids.numpy()[0]) - else: - result = list(batch.input_ids.tolist()[0]) + result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) @@ -188,7 +182,7 @@ def test_prepare_batch(self): def test_empty_target_text(self): tokenizer = self.siglip_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") # check if input_ids are returned and no decoder_input_ids self.assertIn("input_ids", batch) self.assertNotIn("decoder_input_ids", batch) @@ -198,7 +192,7 @@ def test_max_length(self): tokenizer = self.siglip_tokenizer tgt_text = ["Summary of the text.", "Another summary."] targets = tokenizer( - text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt" ) self.assertEqual(32, targets["input_ids"].shape[1]) diff --git a/tests/models/splinter/test_tokenization_splinter.py b/tests/models/splinter/test_tokenization_splinter.py index 8fdf5b75b00e..c87d3590d7d2 100644 --- a/tests/models/splinter/test_tokenization_splinter.py +++ b/tests/models/splinter/test_tokenization_splinter.py @@ -14,7 +14,7 @@ import unittest from tests.test_tokenization_common import TokenizerTesterMixin -from transformers import SplinterTokenizerFast, is_tf_available, is_torch_available +from transformers import SplinterTokenizerFast from transformers.models.splinter import SplinterTokenizer from transformers.testing_utils import get_tests_dir, slow @@ -22,14 +22,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/vocab.txt") -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - - class SplinterTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = SplinterTokenizer rust_tokenizer_class = SplinterTokenizerFast @@ -128,7 +120,7 @@ def test_max_length(self): max_length=max_length, padding="max_length", truncation=True, - return_tensors=FRAMEWORK, + return_tensors="pt", ) self.assertEqual(len(tokenized["input_ids"]), len(texts)) self.assertEqual(len(tokenized["input_ids"][0]), max_length) diff --git a/tests/models/t5/test_tokenization_t5.py b/tests/models/t5/test_tokenization_t5.py index cfc689eaf1c0..cbbe6e2c916e 100644 --- a/tests/models/t5/test_tokenization_t5.py +++ b/tests/models/t5/test_tokenization_t5.py @@ -20,20 +20,12 @@ from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_seqio, require_tokenizers, slow -from transformers.utils import is_tf_available, is_torch_available from ...test_tokenization_common import TokenizerTesterMixin SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - @require_sentencepiece @require_tokenizers @@ -188,13 +180,10 @@ def test_prepare_batch(self): tokenizer = self.t5_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") self.assertIsInstance(batch, BatchEncoding) - if FRAMEWORK != "jax": - result = list(batch.input_ids.numpy()[0]) - else: - result = list(batch.input_ids.tolist()[0]) + result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) @@ -204,7 +193,7 @@ def test_prepare_batch(self): def test_empty_target_text(self): tokenizer = self.t5_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + batch = tokenizer(src_text, padding=True, return_tensors="pt") # check if input_ids are returned and no decoder_input_ids self.assertIn("input_ids", batch) self.assertIn("attention_mask", batch) @@ -218,7 +207,7 @@ def test_max_length(self): "Another summary.", ] targets = tokenizer( - text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt" ) self.assertEqual(32, targets["input_ids"].shape[1]) @@ -226,7 +215,7 @@ def test_outputs_not_longer_than_maxlen(self): tokenizer = self.t5_base_tokenizer batch = tokenizer( - ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK + ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors="pt" ) self.assertIsInstance(batch, BatchEncoding) # Since T5 does NOT have a max input length, diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py index 3a05d98bfc0a..f5270ee84d2c 100644 --- a/tests/models/udop/test_tokenization_udop.py +++ b/tests/models/udop/test_tokenization_udop.py @@ -22,9 +22,7 @@ SpecialTokensMixin, UdopTokenizer, UdopTokenizerFast, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -1593,12 +1591,7 @@ def test_batch_encode_dynamic_overflowing(self): tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): - if is_torch_available(): - returned_tensor = "pt" - elif is_tf_available(): - returned_tensor = "tf" - else: - returned_tensor = "jax" + returned_tensor = "pt" # Single example words, boxes = self.get_words_and_boxes() @@ -1897,18 +1890,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -1916,7 +1897,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT, or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 8fb87136a024..2ef189cd2956 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -21,7 +21,6 @@ from pathlib import Path import datasets -import numpy as np from huggingface_hub import HfFolder, Repository, delete_repo from requests.exceptions import HTTPError @@ -34,7 +33,6 @@ T5ForConditionalGeneration, TextClassificationPipeline, TextGenerationPipeline, - TFAutoModelForSequenceClassification, pipeline, ) from transformers.pipelines import PIPELINE_REGISTRY, get_task @@ -51,11 +49,10 @@ require_torch, require_torch_accelerator, require_torch_multi_accelerator, - require_torch_or_tf, slow, torch_device, ) -from transformers.utils import direct_transformers_import, is_tf_available, is_torch_available +from transformers.utils import direct_transformers_import, is_torch_available from transformers.utils import logging as transformers_logging @@ -699,14 +696,6 @@ def check_models_equal_pt(self, model1, model2): return models_are_equal - def check_models_equal_tf(self, model1, model2): - models_are_equal = True - for model1_p, model2_p in zip(model1.weights, model2.weights): - if np.abs(model1_p.numpy() - model2_p.numpy()).sum() > 1e-5: - models_are_equal = False - - return models_are_equal - class CustomPipeline(Pipeline): def _sanitize_parameters(self, **kwargs): @@ -751,31 +740,26 @@ def test_register_pipeline(self): "custom-text-classification", pipeline_class=PairClassificationPipeline, pt_model=AutoModelForSequenceClassification if is_torch_available() else None, - tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None, - default={"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}, + default={"model": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}, type="text", ) assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks() _, task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification") self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ()) - self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ()) self.assertEqual(task_def["type"], "text") self.assertEqual(task_def["impl"], PairClassificationPipeline) - self.assertEqual( - task_def["default"], {"model": {"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}} - ) + self.assertEqual(task_def["default"], {"model": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}) # Clean registry for next tests. del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"] - @require_torch_or_tf + @require_torch def test_dynamic_pipeline(self): PIPELINE_REGISTRY.register_pipeline( "pair-classification", pipeline_class=PairClassificationPipeline, pt_model=AutoModelForSequenceClassification if is_torch_available() else None, - tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None, ) classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert") @@ -792,7 +776,6 @@ def test_dynamic_pipeline(self): "pair-classification": { "impl": "custom_pipeline.PairClassificationPipeline", "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (), - "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (), } }, ) @@ -821,7 +804,7 @@ def test_dynamic_pipeline(self): [{"label": "LABEL_0", "score": 0.505}], ) - @require_torch_or_tf + @require_torch def test_cached_pipeline_has_minimum_calls_to_head(self): # Make sure we have cached the pipeline. _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") @@ -943,7 +926,6 @@ def test_push_to_hub_dynamic_pipeline(self): "pair-classification": { "impl": "custom_pipeline.PairClassificationPipeline", "pt": ("AutoModelForSequenceClassification",), - "tf": (), } }, ) @@ -966,7 +948,6 @@ def test_push_to_hub_dynamic_pipeline(self): "pair-classification": { "impl": f"{USER}/test-dynamic-pipeline--custom_pipeline.PairClassificationPipeline", "pt": ("AutoModelForSequenceClassification",), - "tf": (), } }, ) diff --git a/tests/pipelines/test_pipelines_mask_generation.py b/tests/pipelines/test_pipelines_mask_generation.py index 3fd8f7b5c44a..ff194764b44b 100644 --- a/tests/pipelines/test_pipelines_mask_generation.py +++ b/tests/pipelines/test_pipelines_mask_generation.py @@ -19,7 +19,6 @@ from transformers import ( MODEL_FOR_MASK_GENERATION_MAPPING, - is_tf_available, is_torch_available, is_vision_available, pipeline, @@ -35,11 +34,6 @@ ) -if is_tf_available(): - from transformers import TF_MODEL_FOR_MASK_GENERATION_MAPPING -else: - TF_MODEL_FOR_MASK_GENERATION_MAPPING = None - if is_torch_available(): from transformers import MODEL_FOR_MASK_GENERATION_MAPPING else: @@ -72,9 +66,6 @@ def mask_to_test_readable(mask: Image) -> dict: @require_torch class MaskGenerationPipelineTests(unittest.TestCase): model_mapping = dict(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else []) - tf_model_mapping = dict( - list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [] - ) def get_test_pipeline( self, diff --git a/utils/check_repo.py b/utils/check_repo.py index e932e5bfc24c..6f95f2662d55 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -33,15 +33,13 @@ import os import re -import sys import types import warnings from collections import OrderedDict from difflib import get_close_matches -from importlib.machinery import ModuleSpec from pathlib import Path -from transformers import is_flax_available, is_tf_available, is_torch_available +from transformers import is_torch_available from transformers.models.auto.auto_factory import get_values from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES @@ -74,7 +72,6 @@ "Qwen3VLVisionModel", "Qwen3VLMoeVisionModel", "SwitchTransformersStack", - "TFDPRSpanPredictor", "MaskFormerSwinModel", "MaskFormerSwinPreTrainedModel", "BridgeTowerTextModel", @@ -130,17 +127,10 @@ "RealmScorer", # Not regular model. "RealmForOpenQA", # Not regular model. "ReformerForMaskedLM", # Needs to be setup as decoder. - "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?) - "TFRobertaForMultipleChoice", # TODO: fix - "TFRobertaPreLayerNormForMultipleChoice", # TODO: fix "SeparableConv1D", # Building part of bigger (tested) model. - "FlaxBartForCausalLM", # Building part of bigger (tested) model. - "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM. "OPTDecoderWrapper", - "TFSegformerDecodeHead", # Not a regular model. "AltRobertaModel", # Building part of bigger (tested) model. "BlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models - "TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models "BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model. "BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model. "BarkCausalModel", # Building part of bigger (tested) model. @@ -189,19 +179,12 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ "models/decision_transformer/test_modeling_decision_transformer.py", "models/camembert/test_modeling_camembert.py", - "models/mt5/test_modeling_flax_mt5.py", "models/mbart/test_modeling_mbart.py", "models/mt5/test_modeling_mt5.py", "models/pegasus/test_modeling_pegasus.py", - "models/camembert/test_modeling_tf_camembert.py", - "models/mt5/test_modeling_tf_mt5.py", - "models/xlm_roberta/test_modeling_tf_xlm_roberta.py", - "models/xlm_roberta/test_modeling_flax_xlm_roberta.py", "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py", "models/xlm_roberta/test_modeling_xlm_roberta.py", "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py", - "models/vision_text_dual_encoder/test_modeling_tf_vision_text_dual_encoder.py", - "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py", "models/decision_transformer/test_modeling_decision_transformer.py", "models/bark/test_modeling_bark.py", "models/shieldgemma2/test_modeling_shieldgemma2.py", @@ -236,12 +219,6 @@ "BlipTextModel", "BrosSpadeEEForTokenClassification", "BrosSpadeELForTokenClassification", - "TFBlipForConditionalGeneration", - "TFBlipForImageTextRetrieval", - "TFBlipForQuestionAnswering", - "TFBlipVisionModel", - "TFBlipTextLMHeadModel", - "TFBlipTextModel", "Swin2SRForImageSuperResolution", "BridgeTowerForImageAndTextRetrieval", "BridgeTowerForMaskedLM", @@ -272,8 +249,6 @@ "PerceiverForMultimodalAutoencoding", "PerceiverForOpticalFlow", "SegformerDecodeHead", - "TFSegformerDecodeHead", - "FlaxBeitForMaskedImageModeling", "BeitForMaskedImageModeling", "ChineseCLIPTextModel", "ChineseCLIPVisionModel", @@ -283,14 +258,6 @@ "ClvpModel", "GroupViTTextModel", "GroupViTVisionModel", - "TFCLIPTextModel", - "TFCLIPVisionModel", - "TFGroupViTTextModel", - "TFGroupViTVisionModel", - "FlaxCLIPTextModel", - "FlaxCLIPTextModelWithProjection", - "FlaxCLIPVisionModel", - "FlaxWav2Vec2ForCTC", "DetrForSegmentation", "Pix2StructVisionModel", "Pix2StructTextModel", @@ -328,13 +295,6 @@ "RealmForOpenQA", "RealmScorer", "RealmReader", - "TFDPRReader", - "TFGPT2DoubleHeadsModel", - "TFLayoutLMForQuestionAnswering", - "TFOpenAIGPTDoubleHeadsModel", - "TFRagModel", - "TFRagSequenceForGeneration", - "TFRagTokenForGeneration", "Wav2Vec2ForCTC", "HubertForCTC", "SEWForCTC", @@ -346,8 +306,6 @@ "VisualBertForVisualReasoning", "VisualBertForQuestionAnswering", "VisualBertForMultipleChoice", - "TFWav2Vec2ForCTC", - "TFHubertForCTC", "XCLIPVisionModel", "XCLIPTextModel", "AltCLIPTextModel", @@ -405,20 +363,6 @@ "Florence2VisionBackbone", # Building part of a bigger model ] -# DO NOT edit this list! -# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove) -OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [ - "FlaxBertLayer", - "FlaxBigBirdLayer", - "FlaxRoFormerLayer", - "TFBertLayer", - "TFLxmertEncoder", - "TFLxmertXLayer", - "TFMPNetLayer", - "TFMobileBertLayer", - "TFSegformerLayer", - "TFViTMAELayer", -] # Update this list for models that have multiple model types for the same model doc. MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( @@ -446,10 +390,7 @@ def check_missing_backends(): missing_backends = [] if not is_torch_available(): missing_backends.append("PyTorch") - if not is_tf_available(): - missing_backends.append("TensorFlow") - if not is_flax_available(): - missing_backends.append("Flax") + if len(missing_backends) > 0: missing = ", ".join(missing_backends) if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: @@ -506,15 +447,8 @@ def get_model_modules() -> list[str]: "modeling_encoder_decoder", "modeling_marian", "modeling_retribert", - "modeling_flax_auto", - "modeling_flax_encoder_decoder", "modeling_speech_encoder_decoder", - "modeling_flax_speech_encoder_decoder", - "modeling_flax_vision_encoder_decoder", "modeling_timm_backbone", - "modeling_tf_auto", - "modeling_tf_encoder_decoder", - "modeling_tf_vision_encoder_decoder", "modeling_vision_encoder_decoder", ] modules = [] @@ -545,12 +479,15 @@ def get_models(module: types.ModuleType, include_pretrained: bool = False) -> li List[Tuple[str, type]]: List of models as tuples (class name, actual class). """ models = [] - model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel) for attr_name in dir(module): if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name): continue attr = getattr(module, attr_name) - if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__: + if ( + isinstance(attr, type) + and issubclass(attr, transformers.PreTrainedModel) + and attr.__module__ == module.__name__ + ): models.append((attr_name, attr)) return models @@ -606,11 +543,7 @@ def get_model_test_files() -> list[str]: _ignore_files = [ "test_modeling_common", "test_modeling_encoder_decoder", - "test_modeling_flax_encoder_decoder", - "test_modeling_flax_speech_encoder_decoder", "test_modeling_marian", - "test_modeling_tf_common", - "test_modeling_tf_encoder_decoder", ] test_files = [] model_test_root = os.path.join(PATH_TO_TESTS, "models") @@ -711,9 +644,7 @@ def check_all_models_are_tested(): # Matches a module to its test file. test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file] if len(test_file) == 0: - # We do not test TF or Flax models anymore because they're deprecated. - if not ("modeling_tf" in module.__name__ or "modeling_flax" in module.__name__): - failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.") + failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.") elif len(test_file) > 1: failures.append(f"{module.__name__} has several test files: {test_file}.") else: @@ -732,14 +663,6 @@ def get_all_auto_configured_models() -> list[str]: for attr_name in dir(transformers.models.auto.modeling_auto): if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"): result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name))) - if is_tf_available(): - for attr_name in dir(transformers.models.auto.modeling_tf_auto): - if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"): - result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name))) - if is_flax_available(): - for attr_name in dir(transformers.models.auto.modeling_flax_auto): - if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"): - result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name))) return list(result) @@ -807,14 +730,10 @@ def check_all_auto_object_names_being_defined(): "PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES, } - # Each auto modeling files contains multiple mappings. Let's get them in a dynamic way. - for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]: - module = getattr(transformers.models.auto, module_name, None) - if module is None: - continue - # all mappings in a single auto modeling file - mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] - mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) + module = getattr(transformers.models.auto, "modeling_auto") + # all mappings in a single auto modeling file + mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] + mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) for name, mapping in mappings_to_check.items(): for class_names in mapping.values(): @@ -851,14 +770,10 @@ def check_all_auto_mapping_names_in_config_mapping_names(): "PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES, } - # Each auto modeling files contains multiple mappings. Let's get them in a dynamic way. - for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]: - module = getattr(transformers.models.auto, module_name, None) - if module is None: - continue - # all mappings in a single auto modeling file - mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] - mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) + module = getattr(transformers.models.auto, "modeling_auto") + # all mappings in a single auto modeling file + mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] + mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) for name, mapping in mappings_to_check.items(): for model_type in mapping: @@ -878,14 +793,11 @@ def check_all_auto_mappings_importable(): failures = [] mappings_to_check = {} - # Each auto modeling files contains multiple mappings. Let's get them in a dynamic way. - for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]: - module = getattr(transformers.models.auto, module_name, None) - if module is None: - continue - # all mappings in a single auto modeling file - mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] - mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) + + module = getattr(transformers.models.auto, "modeling_auto") + # all mappings in a single auto modeling file + mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] + mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) for name in mappings_to_check: name = name.replace("_MAPPING_NAMES", "_MAPPING") @@ -895,53 +807,6 @@ def check_all_auto_mappings_importable(): raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) -def check_objects_being_equally_in_main_init(): - """ - Check if a (TensorFlow or Flax) object is in the main __init__ if its counterpart in PyTorch is. - """ - attrs = dir(transformers) - - failures = [] - for attr in attrs: - obj = getattr(transformers, attr) - if hasattr(obj, "__module__") and isinstance(obj.__module__, ModuleSpec): - continue - if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__: - continue - - module_path = obj.__module__ - module_name = module_path.split(".")[-1] - module_dir = ".".join(module_path.split(".")[:-1]) - if ( - module_name.startswith("modeling_") - and not module_name.startswith("modeling_tf_") - and not module_name.startswith("modeling_flax_") - ): - parent_module = sys.modules[module_dir] - - frameworks = [] - if is_tf_available(): - frameworks.append("TF") - if is_flax_available(): - frameworks.append("Flax") - - for framework in frameworks: - other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_") - if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"): - other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_") - other_module = getattr(parent_module, other_module_name) - if hasattr(other_module, f"{framework}{attr}"): - if not hasattr(transformers, f"{framework}{attr}"): - if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: - failures.append(f"{framework}{attr}") - if hasattr(other_module, f"{framework}_{attr}"): - if not hasattr(transformers, f"{framework}_{attr}"): - if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: - failures.append(f"{framework}_{attr}") - if len(failures) > 0: - raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) - - _re_decorator = re.compile(r"^\s*@(\S+)\s+$") @@ -1041,7 +906,6 @@ def find_all_documented_objects() -> list[str]: "SquadV2Processor", "TextDataset", "TextDatasetForNextSentencePrediction", - "TFTrainingArguments", "WarmUp", # TensorFlow object, support is deprecated "Wav2Vec2ForMaskedLM", "Wav2Vec2Tokenizer", @@ -1070,13 +934,11 @@ def find_all_documented_objects() -> list[str]: "MecabTokenizer", # Internal, should never have been in the main init. "ModelCard", # Internal type. "SqueezeBertModule", # Internal building block (should have been called SqueezeBertLayer) - "TFDPRPretrainedReader", # Like an Encoder. "TransfoXLCorpus", # Internal type. "WordpieceTokenizer", # Internal, should never have been in the main init. "absl", # External module "add_end_docstrings", # Internal, should never have been in the main init. "add_start_docstrings", # Internal, should never have been in the main init. - "convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights "logger", # Internal logger "logging", # External module "requires_backends", # Internal function @@ -1137,7 +999,7 @@ def ignore_undocumented(name: str) -> bool: ): return True # All load functions are not documented. - if name.startswith("load_tf") or name.startswith("load_pytorch"): + if name.startswith("load_pytorch"): return True # is_xxx_available functions are not documented. if name.startswith("is_") and name.endswith("_available"): @@ -1160,8 +1022,6 @@ def check_all_objects_are_documented(): # the objects with the following prefixes are not required to be in the docs ignore_prefixes = [ "_", # internal objects - "TF", # TF objects, support is deprecated - "Flax", # Flax objects, support is deprecated ] objects = [c for c in dir(transformers) if c not in modules and not any(c.startswith(p) for p in ignore_prefixes)] undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)] @@ -1298,8 +1158,6 @@ def check_repo_quality(): check_all_auto_mapping_names_in_config_mapping_names() print(" - checking all auto mappings could be imported.") check_all_auto_mappings_importable() - print(" - checking all objects are equally (across frameworks) in the main __init__.") - check_objects_being_equally_in_main_init() print(" - checking the DEPRECATED_MODELS constant is up to date.") check_deprecated_constant_is_up_to_date() diff --git a/utils/create_dummy_models.py b/utils/create_dummy_models.py index 53ee7597d89c..5e0239fb5c60 100644 --- a/utils/create_dummy_models.py +++ b/utils/create_dummy_models.py @@ -43,7 +43,7 @@ logging, ) from transformers.feature_extraction_utils import FeatureExtractionMixin -from transformers.file_utils import is_tf_available, is_torch_available +from transformers.file_utils import is_torch_available from transformers.image_processing_utils import BaseImageProcessor from transformers.models.auto.configuration_auto import AutoConfig, model_type_to_module_name from transformers.models.fsmt import configuration_fsmt @@ -58,16 +58,11 @@ logging.disable_progress_bar() logger = logging.get_logger(__name__) -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - if not is_torch_available(): raise ValueError("Please install PyTorch.") -if not is_tf_available(): - raise ValueError("Please install TensorFlow.") - -FRAMEWORKS = ["pytorch", "tensorflow"] +FRAMEWORKS = ["pytorch"] INVALID_ARCH = [] TARGET_VOCAB_SIZE = 1024 @@ -94,13 +89,6 @@ "CamembertForTokenClassification", "CamembertForQuestionAnswering", "CamembertModel", - "TFCamembertForMultipleChoice", - "TFCamembertForTokenClassification", - "TFCamembertForQuestionAnswering", - "TFCamembertForSequenceClassification", - "TFCamembertForMaskedLM", - "TFCamembertModel", - "TFCamembertForCausalLM", "DecisionTransformerModel", "GraphormerModel", "InformerModel", @@ -111,8 +99,6 @@ "MT5Model", "MT5ForConditionalGeneration", "UMT5ForConditionalGeneration", - "TFMT5ForConditionalGeneration", - "TFMT5Model", "QDQBertForSequenceClassification", "QDQBertForMaskedLM", "QDQBertModel", @@ -137,13 +123,6 @@ "XLMRobertaForCausalLM", "XLMRobertaForSequenceClassification", "XLMRobertaForQuestionAnswering", - "TFXLMRobertaForSequenceClassification", - "TFXLMRobertaForMaskedLM", - "TFXLMRobertaForCausalLM", - "TFXLMRobertaForQuestionAnswering", - "TFXLMRobertaModel", - "TFXLMRobertaForMultipleChoice", - "TFXLMRobertaForTokenClassification", } @@ -759,14 +738,8 @@ def _sanity_check(fast_tokenizer, slow_tokenizer, keep_fast_tokenizer=False): def get_checkpoint_dir(output_dir, model_arch): - """Get framework-agnostic architecture name. Used to save all PT/TF/Flax models into the same directory.""" - + """Get architecture name.""" arch_name = model_arch.__name__ - if arch_name.startswith("TF"): - arch_name = arch_name[2:] - elif arch_name.startswith("Flax"): - arch_name = arch_name[4:] - return os.path.join(output_dir, arch_name) @@ -874,9 +847,6 @@ def build_composite_models(config_class, output_dir): GPT2Tokenizer, GPT2TokenizerFast, SpeechEncoderDecoderModel, - TFEncoderDecoderModel, - TFVisionEncoderDecoderModel, - TFVisionTextDualEncoderModel, VisionEncoderDecoderModel, VisionTextDualEncoderModel, ViTConfig, @@ -898,7 +868,6 @@ def build_composite_models(config_class, output_dir): encoder_class = BertModel decoder_class = BertLMHeadModel model_class = EncoderDecoderModel - tf_model_class = TFEncoderDecoderModel elif config_class.model_type == "vision-encoder-decoder": encoder_config_class = ViTConfig decoder_config_class = GPT2Config @@ -907,7 +876,6 @@ def build_composite_models(config_class, output_dir): encoder_class = ViTModel decoder_class = GPT2LMHeadModel model_class = VisionEncoderDecoderModel - tf_model_class = TFVisionEncoderDecoderModel elif config_class.model_type == "speech-encoder-decoder": encoder_config_class = Wav2Vec2Config decoder_config_class = BertConfig @@ -916,7 +884,6 @@ def build_composite_models(config_class, output_dir): encoder_class = Wav2Vec2Model decoder_class = BertLMHeadModel model_class = SpeechEncoderDecoderModel - tf_model_class = None elif config_class.model_type == "vision-text-dual-encoder": # Not encoder-decoder, but encoder-encoder. We just keep the same name as above to make code easier encoder_config_class = ViTConfig @@ -926,17 +893,16 @@ def build_composite_models(config_class, output_dir): encoder_class = ViTModel decoder_class = BertModel model_class = VisionTextDualEncoderModel - tf_model_class = TFVisionTextDualEncoderModel with tempfile.TemporaryDirectory() as tmpdir: try: # build encoder - models_to_create = {"processor": encoder_processor, "pytorch": (encoder_class,), "tensorflow": []} + models_to_create = {"processor": encoder_processor, "pytorch": (encoder_class,)} encoder_output_dir = os.path.join(tmpdir, "encoder") build(encoder_config_class, models_to_create, encoder_output_dir) # build decoder - models_to_create = {"processor": decoder_processor, "pytorch": (decoder_class,), "tensorflow": []} + models_to_create = {"processor": decoder_processor, "pytorch": (decoder_class,)} decoder_output_dir = os.path.join(tmpdir, "decoder") build(decoder_config_class, models_to_create, decoder_output_dir) @@ -964,10 +930,6 @@ def build_composite_models(config_class, output_dir): ) model.save_pretrained(model_path) - if tf_model_class is not None: - model = tf_model_class.from_pretrained(model_path) - model.save_pretrained(model_path) - # copy the processors encoder_processor_path = os.path.join(encoder_output_dir, "processors") decoder_processor_path = os.path.join(decoder_output_dir, "processors") @@ -981,11 +943,6 @@ def build_composite_models(config_class, output_dir): result["pytorch"] = {model_class.__name__: {"model": model_class.__name__, "checkpoint": model_path}} - result["tensorflow"] = {} - if tf_model_class is not None: - result["tensorflow"] = { - tf_model_class.__name__: {"model": tf_model_class.__name__, "checkpoint": model_path} - } except Exception: result["error"] = ( f"Failed to build models for {config_class.__name__}.", @@ -1226,42 +1183,6 @@ def build(config_class, models_to_create, output_dir): result["pytorch"][pytorch_arch.__name__]["error"] = (error, trace) logger.error(f"{pytorch_arch.__name__}: {error}") - for tensorflow_arch in models_to_create["tensorflow"]: - # Make PT/TF weights compatible - pt_arch_name = tensorflow_arch.__name__[2:] # Remove `TF` - pt_arch = getattr(transformers_module, pt_arch_name) - - result["tensorflow"][tensorflow_arch.__name__] = {} - error = None - if pt_arch.__name__ in result["pytorch"] and result["pytorch"][pt_arch.__name__]["checkpoint"] is not None: - ckpt = get_checkpoint_dir(output_dir, pt_arch) - # Use the same weights from PyTorch. - try: - model = tensorflow_arch.from_pretrained(ckpt) - model.save_pretrained(ckpt) - except Exception as e: - # Conversion may fail. Let's not create a model with different weights to avoid confusion (for now). - model = None - error = f"Failed to convert the pytorch model to the tensorflow model for {pt_arch}: {e}" - trace = traceback.format_exc() - else: - try: - model = build_model(tensorflow_arch, tiny_config, output_dir=output_dir) - except Exception as e: - model = None - error = f"Failed to create the tensorflow model for {tensorflow_arch}: {e}" - trace = traceback.format_exc() - - result["tensorflow"][tensorflow_arch.__name__]["model"] = ( - model.__class__.__name__ if model is not None else None - ) - result["tensorflow"][tensorflow_arch.__name__]["checkpoint"] = ( - get_checkpoint_dir(output_dir, tensorflow_arch) if model is not None else None - ) - if error is not None: - result["tensorflow"][tensorflow_arch.__name__]["error"] = (error, trace) - logger.error(f"{tensorflow_arch.__name__}: {error}") - if not result["error"]: del result["error"] if not result["warnings"]: @@ -1292,7 +1213,7 @@ def build_tiny_model_summary(results, organization=None, token=None): continue for arch_name in results[config_name][framework]: model_classes = [arch_name] - base_arch_name = arch_name[2:] if arch_name.startswith("TF") else arch_name + base_arch_name = arch_name # tiny model is not created for `arch_name` if results[config_name][framework][arch_name]["model"] is None: model_classes = [] @@ -1423,12 +1344,8 @@ def create_tiny_models( for x in dir(transformers_module) if x.startswith("MODEL_") and x.endswith("_MAPPING") and x != "MODEL_NAMES_MAPPING" ] - _tensorflow_arch_mappings = [ - x for x in dir(transformers_module) if x.startswith("TF_MODEL_") and x.endswith("_MAPPING") - ] pytorch_arch_mappings = [getattr(transformers_module, x) for x in _pytorch_arch_mappings] - tensorflow_arch_mappings = [getattr(transformers_module, x) for x in _tensorflow_arch_mappings] config_classes = CONFIG_MAPPING.values() if not all: @@ -1441,9 +1358,8 @@ def create_tiny_models( for c in config_classes: processors = processor_type_map[c] models = get_architectures_from_config_class(c, pytorch_arch_mappings, models_to_skip) - tf_models = get_architectures_from_config_class(c, tensorflow_arch_mappings, models_to_skip) - if len(models) + len(tf_models) > 0: - to_create[c] = {"processor": processors, "pytorch": models, "tensorflow": tf_models} + if len(models) > 0: + to_create[c] = {"processor": processors, "pytorch": models} results = {} if num_workers <= 1: From e58825c9d3c1755ac15b9523282db4db9656b9a7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 16:58:08 +0200 Subject: [PATCH 05/35] more --- docker/consistency.dockerfile | 4 +- docs/source/ar/_toctree.yml | 6 -- docs/source/en/_toctree.yml | 2 - docs/source/hi/_toctree.yml | 4 +- docs/source/ja/_toctree.yml | 6 -- docs/source/ko/_toctree.yml | 2 - docs/source/ms/_toctree.yml | 4 - docs/source/zh/_toctree.yml | 4 - .../image_processing_new_imgproc_model.py | 3 +- .../modeling_dummy_bert.py | 92 +------------------ .../modeling_from_uppercase_model.py | 12 +-- .../modeling_multimodal2.py | 21 +---- .../modeling_my_new_model2.py | 2 +- .../modeling_new_task_model.py | 13 ++- .../modular-transformers/modeling_roberta.py | 92 +------------------ .../modular-transformers/modeling_super.py | 4 +- 16 files changed, 22 insertions(+), 249 deletions(-) diff --git a/docker/consistency.dockerfile b/docker/consistency.dockerfile index e569307f92dc..42f4b770f4fd 100644 --- a/docker/consistency.dockerfile +++ b/docker/consistency.dockerfile @@ -6,10 +6,8 @@ RUN apt-get update && apt-get install -y time git g++ pkg-config make git-lfs ENV UV_PYTHON=/usr/local/bin/python RUN pip install uv && uv pip install --no-cache-dir -U pip setuptools GitPython RUN uv pip install --no-cache-dir --upgrade 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu -# tensorflow pin matching setup.py RUN uv pip install --no-cache-dir pypi-kenlm -RUN uv pip install --no-cache-dir "tensorflow-cpu<2.16" "tf-keras<2.16" -RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,quality,testing,torch-speech,vision]" +RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[quality,testing,torch-speech,vision]" RUN git lfs install RUN uv pip uninstall transformers diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml index a754abc76c95..8cd45939b3b6 100644 --- a/docs/source/ar/_toctree.yml +++ b/docs/source/ar/_toctree.yml @@ -123,8 +123,6 @@ title: تشغيل التدريب على Amazon SageMaker - local: serialization title: التصدير إلى ONNX - - local: tflite - title: التصدير إلى TFLite - local: torchscript title: التصدير إلى TorchScript - local: notebooks @@ -184,8 +182,6 @@ # title: التدريب الفعال على وحدة المعالجة المركزية (CPU) # - local: perf_train_cpu_many # title: التدريب الموزع لوحدة المعالجة المركزية (CPU) -# - local: perf_train_tpu_tf -# title: التدريب على (TPU) باستخدام TensorFlow # - local: perf_train_special # title: تدريب PyTorch على Apple silicon # - local: perf_hardware @@ -203,8 +199,6 @@ # title: إنشاء نموذج كبير # - local: debugging # title: تصحيح الأخطاء البرمجية -# - local: tf_xla -# title: تكامل XLA لنماذج TensorFlow # - local: perf_torch_compile # title: تحسين الاستدلال باستخدام `torch.compile()` # title: الأداء وقابلية التوسع diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3d1b0b169636..0a64dc03510f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -220,8 +220,6 @@ sections: - local: serialization title: ONNX - - local: tflite - title: LiteRT - local: executorch title: ExecuTorch - local: torchscript diff --git a/docs/source/hi/_toctree.yml b/docs/source/hi/_toctree.yml index 72759457a5c8..f48003d67323 100644 --- a/docs/source/hi/_toctree.yml +++ b/docs/source/hi/_toctree.yml @@ -2,6 +2,4 @@ - local: pipeline_tutorial title: पाइपलाइनों के साथ अनुमान चलाएँ - local: accelerate - title: 🤗 Accelerate के साथ वितरित प्रशिक्षण सेट करें - - local: tflite - title: TFLite में निर्यात करें \ No newline at end of file + title: 🤗 Accelerate के साथ वितरित प्रशिक्षण सेट करें \ No newline at end of file diff --git a/docs/source/ja/_toctree.yml b/docs/source/ja/_toctree.yml index a8a01dbd9cd4..3f74f498fbcf 100644 --- a/docs/source/ja/_toctree.yml +++ b/docs/source/ja/_toctree.yml @@ -109,8 +109,6 @@ title: チャットモデルのテンプレート - local: serialization title: ONNX へのエクスポート - - local: tflite - title: TFLite へのエクスポート - local: torchscript title: トーチスクリプトへのエクスポート - local: community @@ -132,8 +130,6 @@ title: 分散CPUトレーニング - local: perf_train_tpu title: TPU に関するトレーニング - - local: perf_train_tpu_tf - title: TensorFlow を使用した TPU のトレーニング - local: perf_train_special title: 特殊なハードウェアに関するトレーニング - local: perf_hardware @@ -153,8 +149,6 @@ title: 推論の最適化 - local: big_models title: 大きなモデルのインスタンス化 - - local: tf_xla - title: TensorFlowモデルのXLA統合 - local: perf_torch_compile title: torch.compile()を使用した推論の最適化 title: パフォーマンスとスケーラビリティ diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index 2412e497556f..afc0bcf4fa35 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -208,8 +208,6 @@ sections: - local: serialization title: ONNX로 내보내기 - - local: tflite - title: TFLite로 내보내기 - local: executorch title: ExecuTorch - local: torchscript diff --git a/docs/source/ms/_toctree.yml b/docs/source/ms/_toctree.yml index 56a4744b8b86..f57a5bab78e9 100644 --- a/docs/source/ms/_toctree.yml +++ b/docs/source/ms/_toctree.yml @@ -115,8 +115,6 @@ title: Latihan pada banyak CPU - local: perf_train_tpu title: Latihan mengenai TPU - - local: perf_train_tpu_tf - title: Latihan tentang TPU dengan TensorFlow - local: perf_train_special title: Latihan mengenai Perkakasan Khusus - local: perf_infer_cpu @@ -135,8 +133,6 @@ title: Penyahpepijatan - local: hpo_train title: Carian Hiperparameter menggunakan API Pelatih - - local: tf_xla - title: Penyepaduan XLA untuk Model TensorFlow title: Prestasi dan kebolehskalaan - sections: - local: contributing diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 50123200bcb7..0773829e9a6a 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -44,8 +44,6 @@ title: 聊天模型的模板 - local: serialization title: 导出为 ONNX - - local: tflite - title: 导出为 TFLite - local: torchscript title: 导出为 TorchScript - local: gguf @@ -76,8 +74,6 @@ title: 实例化大模型 - local: debugging title: 问题定位及解决 - - local: tf_xla - title: TensorFlow模型的XLA集成 - local: perf_torch_compile title: 使用 `torch.compile()` 优化推理 title: 性能和可扩展性 diff --git a/examples/modular-transformers/image_processing_new_imgproc_model.py b/examples/modular-transformers/image_processing_new_imgproc_model.py index 4614c8cdaa52..d795dc5b32ab 100644 --- a/examples/modular-transformers/image_processing_new_imgproc_model.py +++ b/examples/modular-transformers/image_processing_new_imgproc_model.py @@ -152,7 +152,7 @@ def preprocess( images: ImageInput, do_resize: Optional[bool] = None, size: Optional[dict[str, int]] = None, - resample: PILImageResampling = None, + resample: Optional[PILImageResampling] = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, do_normalize: Optional[bool] = None, @@ -221,6 +221,7 @@ def preprocess( size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) + images = self.fetch_images(images) images = make_flat_list_of_images(images) if not valid_images(images): diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index b490ba96f6cd..bf4b7ce94ee5 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -5,11 +5,9 @@ # modular_dummy_bert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math -import os from typing import Optional, Union import torch -from packaging import version from torch import nn from ...activations import ACT2FN @@ -19,7 +17,7 @@ from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import auto_docstring, get_torch_version, logging +from ...utils import auto_docstring, logging from ...utils.deprecation import deprecate_kwarg from .configuration_dummy_bert import DummyBertConfig @@ -36,8 +34,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -228,7 +225,6 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention): def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from DummyBertSelfAttention @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") @@ -308,14 +304,6 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create @@ -655,83 +643,9 @@ def forward(self, hidden_states): return hidden_states -def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - @auto_docstring class DummyBertPreTrainedModel(PreTrainedModel): config: DummyBertConfig - load_tf_weights = load_tf_weights_in_dummy_bert base_model_prefix = "dummy_bert" supports_gradient_checkpointing = True _supports_sdpa = True @@ -739,8 +653,6 @@ class DummyBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/examples/modular-transformers/modeling_from_uppercase_model.py b/examples/modular-transformers/modeling_from_uppercase_model.py index 393ca6f5a137..6c7a0c776a8d 100644 --- a/examples/modular-transformers/modeling_from_uppercase_model.py +++ b/examples/modular-transformers/modeling_from_uppercase_model.py @@ -12,13 +12,9 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...utils import logging from .configuration_from_uppercase_model import FromUppercaseModelTextConfig, FromUppercaseModelVisionConfig -logger = logging.get_logger(__name__) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -96,13 +92,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py index bb011ee126b0..44b591fafad2 100644 --- a/examples/modular-transformers/modeling_multimodal2.py +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -16,13 +16,10 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import auto_docstring, can_return_tuple, logging, torch_int +from ...utils import auto_docstring, can_return_tuple, torch_int from .configuration_multimodal2 import Multimodal2Config, Multimodal2TextConfig, Multimodal2VisionConfig -logger = logging.get_logger(__name__) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -100,13 +97,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -196,13 +187,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 4c1c8b0c0cb6..27593bddf50e 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -220,7 +220,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor]: + ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index b1ca97b5fb74..13ef7e08271f 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -10,7 +10,7 @@ import torch from torch import nn -from ...cache_utils import Cache, HybridCache, StaticCache +from ...cache_utils import Cache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast @@ -93,7 +93,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _no_split_modules = ["NewTaskModelMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True + _can_compile_fullgraph = False _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -166,8 +166,6 @@ def _update_causal_mask( inputs_lead_dim, sequence_length = input_tensor.shape[:2] if using_static_cache: target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -256,8 +254,8 @@ def get_placeholder_mask( @auto_docstring def forward( self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, @@ -505,7 +503,8 @@ def prepare_inputs_for_generation( if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + is_static_hybrid_cache = isinstance(past_key_values, StaticCache) and any(past_key_values.is_sliding) + if cache_position[0] == 0 and is_static_hybrid_cache: input_tensor = inputs_embeds if inputs_embeds is not None else input_ids causal_mask = self.model._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index dfa8fefea6ab..f9db8677c0d5 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -5,12 +5,10 @@ # modular_roberta.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math -import os from typing import Optional, Union import torch import torch.nn as nn -from packaging import version from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache @@ -19,7 +17,7 @@ from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import auto_docstring, get_torch_version, logging +from ...utils import auto_docstring, logging from ...utils.deprecation import deprecate_kwarg from .configuration_roberta import RobertaConfig @@ -38,8 +36,7 @@ def __init__(self, config): ) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -231,7 +228,6 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from RobertaSelfAttention @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") @@ -311,14 +307,6 @@ def forward( if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create @@ -658,83 +646,9 @@ def forward(self, hidden_states): return hidden_states -def load_tf_weights_in_roberta(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - @auto_docstring class RobertaPreTrainedModel(PreTrainedModel): config: RobertaConfig - load_tf_weights = load_tf_weights_in_roberta base_model_prefix = "roberta" supports_gradient_checkpointing = True _supports_sdpa = True @@ -742,8 +656,6 @@ class RobertaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 6927dab86dc1..9215730ed036 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -46,6 +46,8 @@ def extra_repr(self): class SuperRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: SuperConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" @@ -260,7 +262,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor]: + ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention From 2d7b5afa1f5fe5f16b9a0f8bc419a166c626e422 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 18:01:35 +0200 Subject: [PATCH 06/35] continue the quest: remove loading tf/jax checkpoints --- src/transformers/configuration_utils.py | 5 +- src/transformers/modeling_utils.py | 276 +++++------------- src/transformers/models/auto/auto_factory.py | 222 +------------- .../modeling_encoder_decoder.py | 104 ------- src/transformers/models/rag/modeling_rag.py | 8 - .../modeling_speech_encoder_decoder.py | 8 - .../modeling_vision_encoder_decoder.py | 135 --------- src/transformers/pipelines/__init__.py | 5 +- .../pipelines/automatic_speech_recognition.py | 4 +- src/transformers/pipelines/base.py | 4 +- .../pipelines/question_answering.py | 3 +- src/transformers/quantizers/auto.py | 4 +- .../quantizers/quantizer_bitnet.py | 6 - .../quantizers/quantizer_bnb_4bit.py | 6 - .../quantizers/quantizer_bnb_8bit.py | 6 - src/transformers/quantizers/quantizer_eetq.py | 6 - .../quantizers/quantizer_finegrained_fp8.py | 6 - src/transformers/quantizers/quantizer_hqq.py | 6 - tests/models/auto/test_modeling_auto.py | 8 +- .../test_pipelines_audio_classification.py | 2 - .../test_pipelines_feature_extraction.py | 2 - tests/pipelines/test_pipelines_fill_mask.py | 3 +- .../test_pipelines_image_classification.py | 5 +- ...test_pipelines_image_feature_extraction.py | 2 - .../pipelines/test_pipelines_image_to_text.py | 3 +- .../test_pipelines_question_answering.py | 9 +- .../pipelines/test_pipelines_summarization.py | 17 +- .../test_pipelines_text2text_generation.py | 2 - .../test_pipelines_text_classification.py | 6 - .../test_pipelines_text_generation.py | 5 +- .../pipelines/test_pipelines_text_to_audio.py | 3 +- .../test_pipelines_token_classification.py | 6 - tests/pipelines/test_pipelines_translation.py | 2 - .../test_pipelines_video_classification.py | 3 +- tests/pipelines/test_pipelines_zero_shot.py | 6 - tests/test_pipeline_mixin.py | 15 +- tests/test_tokenization_common.py | 6 +- 37 files changed, 99 insertions(+), 820 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 126b683e672d..933d21b2b436 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -100,9 +100,8 @@ class PretrainedConfig(PushToHubMixin): Arg: name_or_path (`str`, *optional*, defaults to `""`): - Store the string that was passed to [`PreTrainedModel.from_pretrained`] or - [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created - with such a method. + Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` + if the configuration was created with such a method. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not the model should return all hidden-states. output_attentions (`bool`, *optional*, defaults to `False`): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd5c7780133c..4be12568ba43 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -79,11 +79,8 @@ ADAPTER_WEIGHTS_NAME, CONFIG_NAME, DUMMY_INPUTS, - FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, - TF2_WEIGHTS_NAME, - TF_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ContextManagers, @@ -568,11 +565,7 @@ def load_state_dict( "model. Make sure you have saved the model properly." ) from e except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " - f"at '{checkpoint_file}'. " - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." - ) + raise OSError(f"Unable to load weights from pytorch checkpoint file '{checkpoint_file}'.") def set_initialized_submodules(model, state_dict_keys): @@ -1004,8 +997,6 @@ def _get_resolved_checkpoint_files( subfolder: str, variant: Optional[str], gguf_file: Optional[str], - from_tf: bool, - from_flax: bool, use_safetensors: bool, cache_dir: str, force_download: bool, @@ -1032,27 +1023,14 @@ def _get_resolved_checkpoint_files( # If the filename is explicitly defined, load this by default. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename) is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json") - elif from_tf and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") - ): - # Load from a TF 1.0 checkpoint in priority if from_tf - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") - elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)): - # Load from a TF 2.0 checkpoint in priority if from_tf - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) - elif from_flax and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - ): - # Load from a Flax checkpoint in priority if from_flax - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - elif use_safetensors is not False and os.path.isfile( + elif use_safetensors and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) ): # Load from a safetensors checkpoint archive_file = os.path.join( pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) ) - elif use_safetensors is not False and os.path.isfile( + elif use_safetensors and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) ): # Load from a sharded safetensors checkpoint @@ -1075,24 +1053,6 @@ def _get_resolved_checkpoint_files( pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) ) is_sharded = True - # At this stage we don't have a weight file so we will raise an error. - elif not use_safetensors and ( - os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")) - or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)) - ): - raise OSError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" - " `from_tf=True` to load this model from those weights." - ) - elif not use_safetensors and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - ): - raise OSError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" - " to load this model from those weights." - ) elif use_safetensors: raise OSError( f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" @@ -1100,21 +1060,12 @@ def _get_resolved_checkpoint_files( ) else: raise OSError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" - f" {pretrained_model_name_or_path}." + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)}, or {_add_variant(WEIGHTS_NAME, variant)}," + f" found in directory {pretrained_model_name_or_path}." ) elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): archive_file = pretrained_model_name_or_path is_local = True - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): - if not from_tf: - raise ValueError( - f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " - "from_tf to True to load from this checkpoint." - ) - archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") - is_local = True elif is_remote_url(pretrained_model_name_or_path): filename = pretrained_model_name_or_path resolved_archive_file = download_url(pretrained_model_name_or_path) @@ -1123,11 +1074,7 @@ def _get_resolved_checkpoint_files( if transformers_explicit_filename is not None: filename = transformers_explicit_filename is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json") - elif from_tf: - filename = TF2_WEIGHTS_NAME - elif from_flax: - filename = FLAX_WEIGHTS_NAME - elif use_safetensors is not False: + elif use_safetensors: filename = _add_variant(SAFE_WEIGHTS_NAME, variant) else: filename = _add_variant(WEIGHTS_NAME, variant) @@ -1232,19 +1179,7 @@ def _get_resolved_checkpoint_files( "cache_dir": cache_dir, "local_files_only": local_files_only, } - if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." - " Use `from_tf=True` to load this model from those weights." - ) - elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): - raise OSError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" - " `from_flax=True` to load this model from those weights." - ) - elif variant is not None and has_file( + if variant is not None and has_file( pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs ): raise OSError( @@ -1255,8 +1190,7 @@ def _get_resolved_checkpoint_files( else: raise OSError( f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}." ) except OSError: @@ -1269,8 +1203,7 @@ def _get_resolved_checkpoint_files( f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" " from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." ) from e if is_local: @@ -4513,7 +4446,7 @@ def from_pretrained( local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", - use_safetensors: Optional[bool] = None, + use_safetensors: bool = True, weights_only: bool = True, **kwargs, ) -> SpecificPreTrainedModelType: @@ -4537,13 +4470,6 @@ def from_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. - - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, - `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to - `True`. - `None` if you are both providing the configuration and state dictionary (resp. with keyword arguments `config` and `state_dict`). model_args (sequence of positional arguments, *optional*): @@ -4572,12 +4498,6 @@ def from_pretrained( cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_tf (`bool`, *optional*, defaults to `False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - from_flax (`bool`, *optional*, defaults to `False`): - Load the model weights from a Flax checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a @@ -4693,11 +4613,10 @@ def from_pretrained( In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here. variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_tf` or `from_flax`. - use_safetensors (`bool`, *optional*, defaults to `None`): - Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` - is not installed, it will be set to `False`. + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. + use_safetensors (`bool`, *optional*, defaults to `True`): + Whether or not to use `safetensors` checkpoints. Defaults to `True`. If `safetensors` is not installed, + it will be set to `False`. weights_only (`bool`, *optional*, defaults to `True`): Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). @@ -4738,16 +4657,9 @@ def from_pretrained( >>> # Update configuration during loading. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) >>> assert model.config.output_attentions == True - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). - >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json") - >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) - >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) - >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True) ``` """ state_dict = kwargs.pop("state_dict", None) - from_tf = kwargs.pop("from_tf", False) - from_flax = kwargs.pop("from_flax", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -4790,8 +4702,10 @@ def from_pretrained( # Not used anymore -- remove them from the kwargs _ = kwargs.pop("resume_download", None) _ = kwargs.pop("mirror", None) - _ = kwargs.pop("_fast_init", True) + _ = kwargs.pop("_fast_init", None) _ = kwargs.pop("low_cpu_mem_usage", None) + _ = kwargs.pop("from_tf", None) + _ = kwargs.pop("from_flax", None) # For BC on torch_dtype argument if torch_dtype is not None: @@ -4853,7 +4767,7 @@ def from_pretrained( if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: adapter_kwargs["token"] = token - if use_safetensors is None and not is_safetensors_available(): + if use_safetensors and not is_safetensors_available(): use_safetensors = False if gguf_file is not None and not is_accelerate_available(): @@ -4958,8 +4872,6 @@ def from_pretrained( "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead." ) - from_pt = not (from_tf | from_flax) - user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline @@ -5010,7 +4922,7 @@ def from_pretrained( ) hf_quantizer, config, dtype, device_map = get_hf_quantizer( - config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent + config, quantization_config, dtype, device_map, weights_only, user_agent ) if gguf_file is not None and hf_quantizer is not None: @@ -5033,8 +4945,6 @@ def from_pretrained( subfolder=subfolder, variant=variant, gguf_file=gguf_file, - from_tf=from_tf, - from_flax=from_flax, use_safetensors=use_safetensors, cache_dir=cache_dir, force_download=force_download, @@ -5067,11 +4977,9 @@ def from_pretrained( elif metadata.get("format") == "pt": pass elif metadata.get("format") == "tf": - from_tf = True - logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.") + raise ValueError("The safetensors file found has format `'tf'`, which is incompatible.") elif metadata.get("format") == "flax": - from_flax = True - logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + raise ValueError("The safetensors file found has format `'flax'`, which is incompatible.") elif metadata.get("format") == "mlx": # This is a mlx file, we assume weights are compatible with pt pass @@ -5080,24 +4988,21 @@ def from_pretrained( f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" ) - from_pt = not (from_tf | from_flax) - - if from_pt: - if gguf_file: - from .modeling_gguf_pytorch_utils import load_gguf_checkpoint + if gguf_file: + from .modeling_gguf_pytorch_utils import load_gguf_checkpoint - # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was - # passed directly as a kwarg from now on - with torch.device("meta"): - dummy_model = cls(config) - state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[ - "tensors" - ] + # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was + # passed directly as a kwarg from now on + with torch.device("meta"): + dummy_model = cls(config) + state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[ + "tensors" + ] - # Find the correct dtype based on current state - config, dtype, dtype_orig = _get_dtype( - cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only - ) + # Find the correct dtype based on current state + config, dtype, dtype_orig = _get_dtype( + cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only + ) config.name_or_path = pretrained_model_name_or_path model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called) @@ -5160,40 +5065,36 @@ def _assign_original_dtype(module): if device_map is not None: device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, dtype, keep_in_fp32_regex) + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + # Finalize model weight initialization - if from_tf: - model, loading_info = cls._load_from_tf(model, config, checkpoint_files) - elif from_flax: - model = cls._load_from_flax(model, checkpoint_files) - elif from_pt: - # restore default dtype - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + checkpoint_files, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + device_map=device_map, + disk_offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_regex=keep_in_fp32_regex, + device_mesh=device_mesh, + key_mapping=key_mapping, + weights_only=weights_only, + ) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = cls._load_pretrained_model( - model, - state_dict, - checkpoint_files, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - sharded_metadata=sharded_metadata, - device_map=device_map, - disk_offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=dtype, - hf_quantizer=hf_quantizer, - keep_in_fp32_regex=keep_in_fp32_regex, - device_mesh=device_mesh, - key_mapping=key_mapping, - weights_only=weights_only, - ) # make sure token embedding weights are still tied if needed model.tie_weights() @@ -5286,15 +5187,12 @@ def _assign_original_dtype(module): ) if output_loading_info: - if from_pt: - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - elif from_flax: - loading_info = None + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } return model, loading_info return model @@ -5745,44 +5643,6 @@ def _load_pretrained_model( return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs - @classmethod - def _load_from_tf(cls, model, config, checkpoint_files): - if checkpoint_files[0].endswith(".index"): - # Load from a TensorFlow 1.X checkpoint - provided by original authors - model = cls.load_tf_weights(model, config, checkpoint_files[0][:-6]) # Remove the '.index' - loading_info = None - else: - # Load from our TensorFlow 2.0 checkpoints - try: - from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model - - model, loading_info = load_tf2_checkpoint_in_pytorch_model( - model, checkpoint_files[0], allow_missing_keys=True, output_loading_info=True - ) - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." - " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" - " instructions." - ) - raise - return model, loading_info - - @classmethod - def _load_from_flax(cls, model, checkpoint_files): - try: - from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, checkpoint_files[0]) - except ImportError: - logger.error( - "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" - " installation instructions." - ) - raise - return model - def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = {".".join(key.split(".")[:-1]) for key in names} diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index a8781c8042a6..75c053643f66 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -102,10 +102,6 @@ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args (additional positional arguments, *optional*): Will be passed along to the underlying model `__init__()` method. config ([`PretrainedConfig`], *optional*): @@ -127,9 +123,6 @@ cache_dir (`str` or `os.PathLike`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_tf (`bool`, *optional*, defaults to `False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -182,210 +175,6 @@ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) >>> model.config.output_attentions True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config - ... ) - ``` -""" - -FROM_PRETRAINED_TF_DOCSTRING = """ - Instantiate one of the model classes of the library from a pretrained model. - - The model class to instantiate is selected based on the `model_type` property of the config object (either - passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by - falling back to using pattern matching on `pretrained_model_name_or_path`: - - List options - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this - case, `from_pt` should be set to `True` and a configuration object should be provided as `config` - argument. This loading path is slower than converting the PyTorch model in a TensorFlow model - using the provided conversion scripts and loading the TensorFlow model afterwards. - model_args (additional positional arguments, *optional*): - Will be passed along to the underlying model `__init__()` method. - config ([`PretrainedConfig`], *optional*): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - code_revision (`str`, *optional*, defaults to `"main"`): - The specific revision to use for the code on the Hub, if the code leaves in a different repository than - the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based - system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier - allowed by git. - kwargs (additional keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") - - >>> # Update configuration during loading - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config - ... ) - ``` -""" - -FROM_PRETRAINED_FLAX_DOCSTRING = """ - Instantiate one of the model classes of the library from a pretrained model. - - The model class to instantiate is selected based on the `model_type` property of the config object (either - passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by - falling back to using pattern matching on `pretrained_model_name_or_path`: - - List options - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this - case, `from_pt` should be set to `True` and a configuration object should be provided as `config` - argument. This loading path is slower than converting the PyTorch model in a TensorFlow model - using the provided conversion scripts and loading the TensorFlow model afterwards. - model_args (additional positional arguments, *optional*): - Will be passed along to the underlying model `__init__()` method. - config ([`PretrainedConfig`], *optional*): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - code_revision (`str`, *optional*, defaults to `"main"`): - The specific revision to use for the code on the Hub, if the code leaves in a different repository than - the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based - system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier - allowed by git. - kwargs (additional keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") - - >>> # Update configuration during loading - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config - ... ) ``` """ @@ -400,10 +189,6 @@ def _get_model_class(config, model_mapping): for arch in architectures: if arch in name_to_model: return name_to_model[arch] - elif f"TF{arch}" in name_to_model: - return name_to_model[f"TF{arch}"] - elif f"Flax{arch}" in name_to_model: - return name_to_model[f"Flax{arch}"] # If not architecture is set in the config or match the supported models, the first element of the tuple is the # defaults. @@ -696,12 +481,7 @@ def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base- from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) cls.from_config = classmethod(from_config) - if name.startswith("TF"): - from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING - elif name.startswith("Flax"): - from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING - else: - from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING + from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 30e2370b2240..55a736fd9034 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -14,10 +14,7 @@ # limitations under the License. """Classes to support Encoder-Decoder architectures""" -import gc import inspect -import os -import tempfile import warnings from typing import Optional, Union @@ -204,99 +201,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import EncoderDecoderModel - - >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") - ```""" - - from_tf = kwargs.pop("from_tf", False) - if from_tf: - from transformers import TFEncoderDecoderModel - - # a workaround to load from tensorflow checkpoint - # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get - # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is - # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The - # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`, - # which should not occur when we want to save the components alone. - # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see - # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245 - # (the change in `src/transformers/modeling_tf_utils.py`) - _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - config = _tf_model.config - - # Using `tf_model` instead - encoder = _tf_model.encoder.__class__(_tf_model.config.encoder) - decoder = _tf_model.decoder.__class__(_tf_model.config.decoder) - # Make sure models are built - encoder(encoder.dummy_inputs) - decoder(decoder.dummy_inputs) - - # Get the variable correspondence between `_tf_model` and `encoder` and `decoder` - encoder_variables = {} - for v in encoder.trainable_variables + encoder.non_trainable_variables: - encoder_variables["/".join(v.name.split("/")[1:])] = v - decoder_variables = {} - for v in decoder.trainable_variables + decoder.non_trainable_variables: - decoder_variables["/".join(v.name.split("/")[1:])] = v - - _encoder_variables = {} - for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables: - _encoder_variables["/".join(v.name.split("/")[2:])] = v - _decoder_variables = {} - for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables: - _decoder_variables["/".join(v.name.split("/")[2:])] = v - - # assign weight values to `encoder` and `decoder` from `_tf_model` - for name, v in encoder_variables.items(): - v.assign(_encoder_variables[name]) - for name, v in decoder_variables.items(): - v.assign(_decoder_variables[name]) - - tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) - - # Deal with `enc_to_dec_proj` - if hasattr(_tf_model, "enc_to_dec_proj"): - tf_model(tf_model.dummy_inputs) - tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel) - tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias) - - with tempfile.TemporaryDirectory() as tmpdirname: - encoder_dir = os.path.join(tmpdirname, "encoder") - decoder_dir = os.path.join(tmpdirname, "decoder") - tf_model.encoder.save_pretrained(encoder_dir) - tf_model.decoder.save_pretrained(decoder_dir) - - if hasattr(tf_model, "enc_to_dec_proj"): - enc_to_dec_proj_weight = torch.transpose( - torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0 - ) - enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy()) - - del _tf_model - del tf_model - gc.collect() - - model = EncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True - ) - # This is only for copying some specific attributes of this particular model. - model.config = config - - if hasattr(model, "enc_to_dec_proj"): - model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() - model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() - - return model - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @classmethod def from_encoder_decoder_pretrained( cls, @@ -320,10 +224,6 @@ def from_encoder_decoder_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the decoder. Can be either: @@ -331,10 +231,6 @@ def from_encoder_decoder_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 13389107a2cb..3f646536c66c 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -257,10 +257,6 @@ def from_pretrained_question_encoder_generator( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the generator. Can be either: @@ -268,10 +264,6 @@ def from_pretrained_question_encoder_generator( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 272ebdc741bc..a5a6bc2fbf0b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -187,10 +187,6 @@ def from_encoder_decoder_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the decoder. Can be either: @@ -198,10 +194,6 @@ def from_encoder_decoder_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d6bc2dcc0f8e..09eeba11add7 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -14,9 +14,6 @@ # limitations under the License. """Classes to support Vision-Encoder-Text-Decoder architectures""" -import gc -import os -import tempfile from typing import Optional, Union import torch @@ -158,130 +155,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer - >>> from PIL import Image - >>> import requests - - >>> image_processor = AutoImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en") - >>> decoder_tokenizer = AutoTokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en") - >>> model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> img = Image.open(requests.get(url, stream=True).raw) - >>> pixel_values = image_processor(images=img, return_tensors="pt").pixel_values # Batch size 1 - - >>> output_ids = model.generate( - ... pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True - ... ).sequences - - >>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True) - >>> preds = [pred.strip() for pred in preds] - - >>> assert preds == ["a cat laying on top of a couch next to another cat"] - ```""" - - from_tf = kwargs.pop("from_tf", False) - if from_tf: - from transformers import TFVisionEncoderDecoderModel - - # a workaround to load from tensorflow checkpoint - # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get - # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is - # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The - # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`, - # which should not occur when we want to save the components alone. - # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see - # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245 - # (the change in `src/transformers/modeling_tf_utils.py`) - _tf_model = TFVisionEncoderDecoderModel.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs - ) - config = _tf_model.config - - # Using `tf_model` instead - encoder = _tf_model.encoder.__class__(_tf_model.config.encoder) - decoder = _tf_model.decoder.__class__(_tf_model.config.decoder) - # Make sure models are built - encoder(encoder.dummy_inputs) - decoder(decoder.dummy_inputs) - - # Get the variable correspondence between `_tf_model` and `encoder` and `decoder` - encoder_variables = {} - for v in encoder.trainable_variables + encoder.non_trainable_variables: - encoder_variables["/".join(v.name.split("/")[1:])] = v - decoder_variables = {} - for v in decoder.trainable_variables + decoder.non_trainable_variables: - decoder_variables["/".join(v.name.split("/")[1:])] = v - - _encoder_variables = {} - for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables: - _encoder_variables["/".join(v.name.split("/")[2:])] = v - _decoder_variables = {} - for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables: - _decoder_variables["/".join(v.name.split("/")[2:])] = v - - # assign weight values to `encoder` and `decoder` from `_tf_model` - for name, v in encoder_variables.items(): - v.assign(_encoder_variables[name]) - for name, v in decoder_variables.items(): - v.assign(_decoder_variables[name]) - - tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder) - - # Deal with `enc_to_dec_proj` - if hasattr(_tf_model, "enc_to_dec_proj"): - tf_model(tf_model.dummy_inputs) - tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel) - tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias) - - with tempfile.TemporaryDirectory() as tmpdirname: - encoder_dir = os.path.join(tmpdirname, "encoder") - decoder_dir = os.path.join(tmpdirname, "decoder") - tf_model.encoder.save_pretrained(encoder_dir) - tf_model.decoder.save_pretrained(decoder_dir) - - if hasattr(tf_model, "enc_to_dec_proj"): - enc_to_dec_proj_weight = torch.transpose( - torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0 - ) - enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy()) - - del _tf_model - del tf_model - gc.collect() - - attn_implementation = kwargs.get("attn_implementation") - kwargs_encoder_decoder = {} - if attn_implementation: - kwargs_encoder_decoder = { - "encoder_attn_implementation": attn_implementation, - "decoder_attn_implementation": attn_implementation, - } - - model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( - encoder_dir, - decoder_dir, - encoder_from_tf=True, - decoder_from_tf=True, - **kwargs_encoder_decoder, - ) - # This is only for copying some specific attributes of this particular model. - model.config = config - - if hasattr(model, "enc_to_dec_proj"): - model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() - model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() - - return model - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @classmethod def from_encoder_decoder_pretrained( cls, @@ -306,10 +179,6 @@ def from_encoder_decoder_pretrained( example is `google/vit-base-patch16-224-in21k`. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the text decoder. Can be either: @@ -317,10 +186,6 @@ def from_encoder_decoder_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 30af924611ae..a029bb32df03 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -584,10 +584,9 @@ def pipeline( - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`]. - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`]. - model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*): + model (`str` or [`PreTrainedModel`], *optional*): The model that will be used by the pipeline to make predictions. This can be a model identifier or an - actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch) or - [`TFPreTrainedModel`] (for TensorFlow). + actual instance of a pretrained model inheriting from [`PreTrainedModel`]. If not provided, the default for the `task` will be loaded. config (`str` or [`PretrainedConfig`], *optional*): diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index b4d1b96ea87f..88071764aafb 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -134,9 +134,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) Arguments: - model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + model ([`PreTrainedModel`]): The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from - [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow. + [`PreTrainedModel`]. feature_extractor ([`SequenceFeatureExtractor`]): The feature extractor that will be used by the pipeline to encode waveform for the model. tokenizer ([`PreTrainedTokenizer`]): diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 09eaab16922f..0a555f385698 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -673,9 +673,9 @@ def build_pipeline_init_args( ) -> str: docstring = r""" Arguments: - model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + model ([`PreTrainedModel`]): The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from - [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.""" + [`PreTrainedModel`].""" if has_tokenizer: docstring += r""" tokenizer ([`PreTrainedTokenizer`]): diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 5eeeb51cf389..c62aafc4fc68 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -22,7 +22,6 @@ logger = logging.get_logger(__name__) if TYPE_CHECKING: - from ..modeling_tf_utils import TFPreTrainedModel from ..modeling_utils import PreTrainedModel if is_tokenizers_available(): @@ -258,7 +257,7 @@ class QuestionAnsweringPipeline(ChunkPipeline): def __init__( self, - model: Union["PreTrainedModel", "TFPreTrainedModel"], + model: "PreTrainedModel", tokenizer: PreTrainedTokenizer, modelcard: Optional[ModelCard] = None, task: str = "", diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 42d8626ceaec..6b5bdd3e9c3c 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -295,7 +295,7 @@ def register_quantizer_fn(cls): return register_quantizer_fn -def get_hf_quantizer(config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent): +def get_hf_quantizer(config, quantization_config, dtype, device_map, weights_only, user_agent): pre_quantized = hasattr(config, "quantization_config") if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config): pre_quantized = False @@ -318,8 +318,6 @@ def get_hf_quantizer(config, quantization_config, dtype, from_tf, from_flax, dev if hf_quantizer is not None: hf_quantizer.validate_environment( dtype=dtype, - from_tf=from_tf, - from_flax=from_flax, device_map=device_map, weights_only=weights_only, ) diff --git a/src/transformers/quantizers/quantizer_bitnet.py b/src/transformers/quantizers/quantizer_bitnet.py index a57e732b9823..b8b7e1eb3bd0 100644 --- a/src/transformers/quantizers/quantizer_bitnet.py +++ b/src/transformers/quantizers/quantizer_bitnet.py @@ -50,12 +50,6 @@ def validate_environment(self, *args, **kwargs): if not is_accelerate_available(): raise ImportError("Loading a BitNet quantized model requires accelerate (`pip install accelerate`)") - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Loading ternary weights from tf/flax is currently not supported, please make" - " sure the weights are in PyTorch format." - ) - if not torch.cuda.is_available(): logger.warning_once( "You don't have a GPU available to load the model, the inference will be slow because of weight unpacking" diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 74879fa17ac4..b1fc580142eb 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -96,12 +96,6 @@ def validate_environment(self, *args, **kwargs): bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() validate_bnb_backend_availability(raise_exception=True) - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" - " sure the weights are in PyTorch format." - ) - device_map = kwargs.get("device_map") if ( device_map is not None diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index 1d269765f57f..be044ac3b325 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -93,12 +93,6 @@ def validate_environment(self, *args, **kwargs): bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() validate_bnb_backend_availability(raise_exception=True) - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" - " sure the weights are in PyTorch format." - ) - device_map = kwargs.get("device_map") if ( device_map is not None diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py index 00a8117be9d2..8953f07ea859 100644 --- a/src/transformers/quantizers/quantizer_eetq.py +++ b/src/transformers/quantizers/quantizer_eetq.py @@ -70,12 +70,6 @@ def validate_environment(self, *args, **kwargs): if not is_accelerate_available(): raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)") - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Converting into 8-bit weights from tf/flax weights is currently not supported, please make" - " sure the weights are in PyTorch format." - ) - if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed for quantization.") diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index dc30221b590e..c2f1414eced3 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -38,12 +38,6 @@ def validate_environment(self, *args, **kwargs): if not is_accelerate_available(): raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)") - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Converting into FP8 weights from tf/flax weights is currently not supported, " - "please make sure the weights are in PyTorch format." - ) - if not (torch.cuda.is_available() or is_torch_xpu_available()): raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.") diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index fa1d276c6a1a..f5f60a6d9e4b 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -65,12 +65,6 @@ def validate_environment(self, *args, **kwargs): "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." ) - if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): - raise ValueError( - "Converting weights from tf/flax weights is currently not supported, please make" - " sure the weights are in PyTorch format." - ) - if self.dtype is None: if "dtype" in kwargs: self.dtype = kwargs["dtype"] diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 7af5315f844c..9d6e9569a9dc 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -514,12 +514,12 @@ def test_model_file_not_found(self): ): _ = AutoModel.from_pretrained("hf-internal-testing/config-no-model") - def test_model_from_tf_suggestion(self): - with self.assertRaisesRegex(EnvironmentError, "Use `from_tf=True` to load this model"): + def test_model_from_tf_error(self): + with self.assertRaisesRegex(EnvironmentError, "Can't load the model for"): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") - def test_model_from_flax_suggestion(self): - with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"): + def test_model_from_flax_error(self): + with self.assertRaisesRegex(EnvironmentError, "Can't load the model for"): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") @unittest.skip("Failing on main") diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index 19c16b486514..e30fd976cc97 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -20,7 +20,6 @@ from transformers import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, is_torch_available, ) from transformers.pipelines import AudioClassificationPipeline, pipeline @@ -43,7 +42,6 @@ @is_pipeline_test class AudioClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING - tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING _dataset = None @classmethod diff --git a/tests/pipelines/test_pipelines_feature_extraction.py b/tests/pipelines/test_pipelines_feature_extraction.py index f6cbbb2fea2e..a8321da15eb8 100644 --- a/tests/pipelines/test_pipelines_feature_extraction.py +++ b/tests/pipelines/test_pipelines_feature_extraction.py @@ -20,7 +20,6 @@ FEATURE_EXTRACTOR_MAPPING, IMAGE_PROCESSOR_MAPPING, MODEL_MAPPING, - TF_MODEL_MAPPING, FeatureExtractionPipeline, LxmertConfig, is_torch_available, @@ -36,7 +35,6 @@ @is_pipeline_test class FeatureExtractionPipelineTests(unittest.TestCase): model_mapping = MODEL_MAPPING - tf_model_mapping = TF_MODEL_MAPPING @require_torch def test_small_model_pt(self): diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py index e1e88f63f7fe..af5ee8c4be0c 100644 --- a/tests/pipelines/test_pipelines_fill_mask.py +++ b/tests/pipelines/test_pipelines_fill_mask.py @@ -15,7 +15,7 @@ import gc import unittest -from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline +from transformers import MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline from transformers.pipelines import PipelineException from transformers.testing_utils import ( backend_empty_cache, @@ -34,7 +34,6 @@ @is_pipeline_test class FillMaskPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_MASKED_LM_MAPPING - tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING def tearDown(self): super().tearDown() diff --git a/tests/pipelines/test_pipelines_image_classification.py b/tests/pipelines/test_pipelines_image_classification.py index 04d2dc876b33..f8f8c908095c 100644 --- a/tests/pipelines/test_pipelines_image_classification.py +++ b/tests/pipelines/test_pipelines_image_classification.py @@ -19,7 +19,6 @@ from transformers import ( MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizerBase, is_torch_available, is_vision_available, @@ -30,7 +29,6 @@ is_pipeline_test, nested_simplify, require_torch, - require_torch_or_tf, require_vision, slow, ) @@ -52,11 +50,10 @@ def open(*args, **kwargs): @is_pipeline_test -@require_torch_or_tf +@require_torch @require_vision class ImageClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING - tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING _dataset = None @classmethod diff --git a/tests/pipelines/test_pipelines_image_feature_extraction.py b/tests/pipelines/test_pipelines_image_feature_extraction.py index 8254517c35d8..2705e1385331 100644 --- a/tests/pipelines/test_pipelines_image_feature_extraction.py +++ b/tests/pipelines/test_pipelines_image_feature_extraction.py @@ -19,7 +19,6 @@ from transformers import ( MODEL_MAPPING, - TF_MODEL_MAPPING, TOKENIZER_MAPPING, ImageFeatureExtractionPipeline, is_torch_available, @@ -45,7 +44,6 @@ def prepare_img(): @is_pipeline_test class ImageFeatureExtractionPipelineTests(unittest.TestCase): model_mapping = MODEL_MAPPING - tf_model_mapping = TF_MODEL_MAPPING @require_torch def test_small_model_pt(self): diff --git a/tests/pipelines/test_pipelines_image_to_text.py b/tests/pipelines/test_pipelines_image_to_text.py index ee73a1dfb63b..bc8ac76548ea 100644 --- a/tests/pipelines/test_pipelines_image_to_text.py +++ b/tests/pipelines/test_pipelines_image_to_text.py @@ -16,7 +16,7 @@ import requests -from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available +from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available from transformers.pipelines import ImageToTextPipeline, pipeline from transformers.testing_utils import ( is_pipeline_test, @@ -42,7 +42,6 @@ def open(*args, **kwargs): @require_vision class ImageToTextPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING - tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING def get_test_pipeline( self, diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index e6bd82a846bc..b87742ee03cf 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -18,7 +18,6 @@ from transformers import ( MODEL_FOR_QUESTION_ANSWERING_MAPPING, - TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, LxmertConfig, QuestionAnsweringPipeline, ) @@ -30,7 +29,6 @@ is_torch_available, nested_simplify, require_torch, - require_torch_or_tf, slow, ) @@ -48,14 +46,9 @@ @is_pipeline_test class QAPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING - tf_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if not hasattr(model_mapping, "is_dummy"): model_mapping = {config: model for config, model in model_mapping.items() if config.__name__ not in _TO_SKIP} - if not hasattr(tf_model_mapping, "is_dummy"): - tf_model_mapping = { - config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP - } def get_test_pipeline( self, @@ -416,7 +409,7 @@ def test_large_model_course(self): ) -@require_torch_or_tf +@require_torch class QuestionAnsweringArgumentHandlerTests(unittest.TestCase): def test_argument_handler(self): qa = QuestionAnsweringArgumentHandler() diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index 99c2faf72b54..e58d9264b89d 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -16,9 +16,7 @@ from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, SummarizationPipeline, - TFPreTrainedModel, pipeline, ) from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device @@ -30,7 +28,6 @@ @is_pipeline_test class SummarizationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING def get_test_pipeline( self, @@ -78,17 +75,9 @@ def run_pipeline_test(self, summarizer, _): "ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values) ] if model.config.__class__.__name__ not in model_can_handle_longer_seq: - # Too long and exception is expected. - # For TF models, if the weights are initialized in GPU context, we won't get expected index error from - # the embedding layer. - if not ( - isinstance(model, TFPreTrainedModel) - and len(summarizer.model.trainable_weights) > 0 - and "GPU" in summarizer.model.trainable_weights[0].device - ): - if str(summarizer.device) == "cpu": - with self.assertRaises(Exception): - outputs = summarizer("This " * 1000) + if str(summarizer.device) == "cpu": + with self.assertRaises(Exception): + outputs = summarizer("This " * 1000) outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST) @require_torch diff --git a/tests/pipelines/test_pipelines_text2text_generation.py b/tests/pipelines/test_pipelines_text2text_generation.py index 5e3ed2b32f09..730f707237db 100644 --- a/tests/pipelines/test_pipelines_text2text_generation.py +++ b/tests/pipelines/test_pipelines_text2text_generation.py @@ -16,7 +16,6 @@ from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, Text2TextGenerationPipeline, pipeline, ) @@ -33,7 +32,6 @@ @is_pipeline_test class Text2TextGenerationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING def get_test_pipeline( self, diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 6400ad039bcb..902cc59b0987 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -16,7 +16,6 @@ from transformers import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TextClassificationPipeline, pipeline, ) @@ -45,14 +44,9 @@ @is_pipeline_test class TextClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING - tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING if not hasattr(model_mapping, "is_dummy"): model_mapping = {config: model for config, model in model_mapping.items() if config.__name__ not in _TO_SKIP} - if not hasattr(tf_model_mapping, "is_dummy"): - tf_model_mapping = { - config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP - } @require_torch def test_small_model_pt(self): diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index a408266036c3..456c8ea922d8 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -17,7 +17,6 @@ from transformers import ( MODEL_FOR_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, logging, pipeline, @@ -28,7 +27,6 @@ require_accelerate, require_torch, require_torch_accelerator, - require_torch_or_tf, torch_device, ) @@ -36,10 +34,9 @@ @is_pipeline_test -@require_torch_or_tf +@require_torch class TextGenerationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING - tf_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING @require_torch def test_small_model_pt(self): diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index be49e1c7bdc5..e435ce800fe0 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -26,7 +26,6 @@ is_pipeline_test, require_torch, require_torch_accelerator, - require_torch_or_tf, slow, torch_device, ) @@ -36,7 +35,7 @@ @is_pipeline_test -@require_torch_or_tf +@require_torch class TextToAudioPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING # for now only test text_to_waveform and not text_to_spectrogram diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index c5a4d3b89aa8..19e6d342805c 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -18,7 +18,6 @@ from transformers import ( MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, AutoModelForTokenClassification, AutoTokenizer, TokenClassificationPipeline, @@ -51,14 +50,9 @@ @is_pipeline_test class TokenClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING - tf_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING if not hasattr(model_mapping, "is_dummy"): model_mapping = {config: model for config, model in model_mapping.items() if config.__name__ not in _TO_SKIP} - if not hasattr(tf_model_mapping, "is_dummy"): - tf_model_mapping = { - config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP - } def get_test_pipeline( self, diff --git a/tests/pipelines/test_pipelines_translation.py b/tests/pipelines/test_pipelines_translation.py index 0bb34ba5e635..3b95cce1f70d 100644 --- a/tests/pipelines/test_pipelines_translation.py +++ b/tests/pipelines/test_pipelines_translation.py @@ -18,7 +18,6 @@ from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MBart50TokenizerFast, MBartConfig, MBartForConditionalGeneration, @@ -33,7 +32,6 @@ @is_pipeline_test class TranslationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING def get_test_pipeline( self, diff --git a/tests/pipelines/test_pipelines_video_classification.py b/tests/pipelines/test_pipelines_video_classification.py index d92281fc890d..cea7a94c3910 100644 --- a/tests/pipelines/test_pipelines_video_classification.py +++ b/tests/pipelines/test_pipelines_video_classification.py @@ -24,7 +24,6 @@ nested_simplify, require_av, require_torch, - require_torch_or_tf, require_vision, ) @@ -32,7 +31,7 @@ @is_pipeline_test -@require_torch_or_tf +@require_torch @require_vision @require_av class VideoClassificationPipelineTests(unittest.TestCase): diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py index ed26e911ee57..479854dba972 100644 --- a/tests/pipelines/test_pipelines_zero_shot.py +++ b/tests/pipelines/test_pipelines_zero_shot.py @@ -16,7 +16,6 @@ from transformers import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, Pipeline, ZeroShotClassificationPipeline, pipeline, @@ -43,14 +42,9 @@ @is_pipeline_test class ZeroShotClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING - tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING if not hasattr(model_mapping, "is_dummy"): model_mapping = {config: model for config, model in model_mapping.items() if config.__name__ not in _TO_SKIP} - if not hasattr(tf_model_mapping, "is_dummy"): - tf_model_mapping = { - config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP - } def get_test_pipeline( self, diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index aae6ff9413f1..5002a3e9b946 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -56,7 +56,6 @@ require_pytesseract, require_timm, require_torch, - require_torch_or_tf, require_vision, ) from transformers.utils import direct_transformers_import, logging @@ -143,7 +142,6 @@ test = task_info["test"] task_info["mapping"] = { "pt": getattr(test, "model_mapping", None), - "tf": getattr(test, "tf_model_mapping", None), } @@ -171,7 +169,6 @@ class PipelineTesterMixin: model_tester = None pipeline_model_mapping = None - supported_frameworks = ["pt", "tf"] def run_task_tests(self, task, dtype="float32"): """Run pipeline tests for a specific `task` @@ -200,12 +197,6 @@ def run_task_tests(self, task, dtype="float32"): model_arch_name = model_architecture.__name__ model_type = model_architecture.config_class.model_type - # Get the canonical name - for _prefix in ["Flax", "TF"]: - if model_arch_name.startswith(_prefix): - model_arch_name = model_arch_name[len(_prefix) :] - break - if model_arch_name not in tiny_model_summary: continue @@ -562,7 +553,7 @@ def test_pipeline_fill_mask_fp16(self): self.run_task_tests(task="fill-mask", dtype="float16") @is_pipeline_test - @require_torch_or_tf + @require_torch @require_vision def test_pipeline_image_classification(self): self.run_task_tests(task="image-classification") @@ -698,7 +689,7 @@ def test_pipeline_text_classification_fp16(self): self.run_task_tests(task="text-classification", dtype="float16") @is_pipeline_test - @require_torch_or_tf + @require_torch def test_pipeline_text_generation(self): self.run_task_tests(task="text-generation") @@ -736,7 +727,7 @@ def test_pipeline_translation_fp16(self): self.run_task_tests(task="translation", dtype="float16") @is_pipeline_test - @require_torch_or_tf + @require_torch @require_vision @require_av def test_pipeline_video_classification(self): diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index b0e6bcf2ce29..6f4c6457ec5f 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -67,7 +67,7 @@ if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel + from transformers import PretrainedConfig, PreTrainedModel def use_cache_if_possible(func): @@ -122,11 +122,11 @@ def filter_roberta_detectors(_, pretrained_name: str): def merge_model_tokenizer_mappings( - model_mapping: dict["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]], + model_mapping: dict["PretrainedConfig", "PreTrainedModel"], tokenizer_mapping: dict["PretrainedConfig", tuple["PreTrainedTokenizer", "PreTrainedTokenizerFast"]], ) -> dict[ Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"], - tuple["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]], + tuple["PretrainedConfig", "PreTrainedModel"], ]: configurations = list(model_mapping.keys()) model_tokenizer_mapping = OrderedDict([]) From 5c92286b65e24d9debe3d45334d1b1982829e21c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 9 Sep 2025 18:04:46 +0200 Subject: [PATCH 07/35] style --- src/transformers/pipelines/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 0a555f385698..e9ef235af087 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -798,7 +798,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin): def __init__( self, - model: PreTrainedModel, + model: "PreTrainedModel", tokenizer: Optional[PreTrainedTokenizer] = None, feature_extractor: Optional[PreTrainedFeatureExtractor] = None, image_processor: Optional[BaseImageProcessor] = None, From 035442046d0228f07f4cc4c8fcb95b3e3a548f76 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 15:17:53 +0200 Subject: [PATCH 08/35] fix configs --- src/transformers/models/gemma/configuration_gemma.py | 5 ----- src/transformers/models/gemma/modular_gemma.py | 5 ----- src/transformers/models/lxmert/configuration_lxmert.py | 4 ---- src/transformers/models/lxmert/modeling_lxmert.py | 3 +-- src/transformers/models/xlnet/configuration_xlnet.py | 4 ---- tests/models/lxmert/test_modeling_lxmert.py | 3 --- tests/models/xlnet/test_modeling_xlnet.py | 3 --- utils/check_config_attributes.py | 1 + 8 files changed, 2 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 363af5c3ffc4..58d6c3d08537 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -54,9 +54,6 @@ class GemmaConfig(PretrainedConfig): The attention head dimension. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): @@ -117,7 +114,6 @@ def __init__( num_key_value_heads=16, head_dim=256, hidden_act="gelu_pytorch_tanh", - hidden_activation=None, max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-6, @@ -140,7 +136,6 @@ def __init__( self.head_dim = head_dim self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act - self.hidden_activation = hidden_activation self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 281fcd54fb7d..00dfb9edbcf7 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -79,9 +79,6 @@ class GemmaConfig(PretrainedConfig): The attention head dimension. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): @@ -142,7 +139,6 @@ def __init__( num_key_value_heads=16, head_dim=256, hidden_act="gelu_pytorch_tanh", - hidden_activation=None, max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-6, @@ -165,7 +161,6 @@ def __init__( self.head_dim = head_dim self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act - self.hidden_activation = hidden_activation self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache diff --git a/src/transformers/models/lxmert/configuration_lxmert.py b/src/transformers/models/lxmert/configuration_lxmert.py index 18d3d2e60d7b..cba273e0f19a 100644 --- a/src/transformers/models/lxmert/configuration_lxmert.py +++ b/src/transformers/models/lxmert/configuration_lxmert.py @@ -66,8 +66,6 @@ class LxmertConfig(PretrainedConfig): The vocabulary size of the *token_type_ids* passed into [`BertModel`]. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layer_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the layer normalization layers. l_layers (`int`, *optional*, defaults to 9): Number of hidden layers in the Transformer language encoder. x_layers (`int`, *optional*, defaults to 5): @@ -119,7 +117,6 @@ def __init__( max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, - layer_norm_eps=1e-12, l_layers=9, x_layers=5, r_layers=5, @@ -145,7 +142,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps self.num_qa_labels = num_qa_labels self.num_object_labels = num_object_labels self.num_attr_labels = num_attr_labels diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 552508e76974..5b81ed662e8a 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -187,8 +187,7 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file + # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py index 4a7238eb4c09..d32f05c875bb 100644 --- a/src/transformers/models/xlnet/configuration_xlnet.py +++ b/src/transformers/models/xlnet/configuration_xlnet.py @@ -49,8 +49,6 @@ class XLNetConfig(PretrainedConfig): ff_activation (`str` or `Callable`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the If string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported. - untie_r (`bool`, *optional*, defaults to `True`): - Whether or not to untie relative position biases attn_type (`str`, *optional*, defaults to `"bi"`): The attention type used by the model. Set `"bi"` for XLNet, `"uni"` for Transformer-XL. initializer_range (`float`, *optional*, defaults to 0.02): @@ -150,7 +148,6 @@ def __init__( n_head=16, d_inner=4096, ff_activation="gelu", - untie_r=True, attn_type="bi", initializer_range=0.02, layer_norm_eps=1e-12, @@ -188,7 +185,6 @@ def __init__( self.d_head = d_model // n_head self.ff_activation = ff_activation self.d_inner = d_inner - self.untie_r = untie_r self.attn_type = attn_type self.initializer_range = initializer_range diff --git a/tests/models/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py index cc7f863bd858..033fcc0605d6 100644 --- a/tests/models/lxmert/test_modeling_lxmert.py +++ b/tests/models/lxmert/test_modeling_lxmert.py @@ -54,7 +54,6 @@ def __init__( max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, - layer_norm_eps=1e-12, pad_token_id=0, num_qa_labels=30, num_object_labels=16, @@ -94,7 +93,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps self.pad_token_id = pad_token_id self.num_qa_labels = num_qa_labels self.num_object_labels = num_object_labels @@ -194,7 +192,6 @@ def get_config(self): max_position_embeddings=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - layer_norm_eps=self.layer_norm_eps, pad_token_id=self.pad_token_id, num_qa_labels=self.num_qa_labels, num_object_labels=self.num_object_labels, diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py index b8bed5c822af..ae0e2b9d56df 100644 --- a/tests/models/xlnet/test_modeling_xlnet.py +++ b/tests/models/xlnet/test_modeling_xlnet.py @@ -56,7 +56,6 @@ def __init__( d_inner=128, num_hidden_layers=2, type_sequence_label_size=2, - untie_r=True, bi_data=False, same_length=False, initializer_range=0.05, @@ -83,7 +82,6 @@ def __init__( self.d_inner = 128 self.num_hidden_layers = 5 self.type_sequence_label_size = 2 - self.untie_r = True self.bi_data = False self.same_length = False self.initializer_range = 0.05 @@ -152,7 +150,6 @@ def get_config(self): n_head=self.num_attention_heads, d_inner=self.d_inner, n_layer=self.num_hidden_layers, - untie_r=self.untie_r, mem_len=self.mem_len, clamp_len=self.clamp_len, same_length=self.same_length, diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index d3ca53a56076..06ea41d9ec2c 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -306,6 +306,7 @@ "SmolLM3Config": ["no_rope_layer_interval"], "Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm` "VaultGemmaConfig": ["tie_word_embeddings"], + "GemmaConfig": ["tie_word_embeddings"], } From ca1656948fe373be19eb025f0a48280a1d06810e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 15:31:10 +0200 Subject: [PATCH 09/35] oups forgot conflict --- tests/models/siglip/test_tokenization_siglip.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/models/siglip/test_tokenization_siglip.py b/tests/models/siglip/test_tokenization_siglip.py index 6d54d7dad26a..843058c8a019 100644 --- a/tests/models/siglip/test_tokenization_siglip.py +++ b/tests/models/siglip/test_tokenization_siglip.py @@ -20,11 +20,6 @@ from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow -<<<<<<< HEAD -from transformers.utils import is_tf_available, is_torch_available -======= -from transformers.utils import cached_property ->>>>>>> 928d4415fa (more and more) from ...test_tokenization_common import TokenizerTesterMixin From 896e965f429b4153c69bfb5fa080dea4d51819ef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 16:58:19 +0200 Subject: [PATCH 10/35] continue --- README.md | 2 +- docker/transformers-all-latest-gpu/Dockerfile | 4 +- docker/transformers-gpu/Dockerfile | 1 - docker/transformers-past-gpu/Dockerfile | 2 +- .../transformers-pytorch-amd-gpu/Dockerfile | 3 - docker/transformers-pytorch-gpu/Dockerfile | 2 - docker/transformers-tensorflow-gpu/Dockerfile | 25 --- src/transformers/commands/convert.py | 165 ------------------ src/transformers/data/processors/utils.py | 2 +- .../feature_extraction_sequence_utils.py | 2 +- src/transformers/modeling_utils.py | 2 - .../models/deit/configuration_deit.py | 4 +- .../modeling_gptsan_japanese.py | 9 - .../models/dpt/configuration_dpt.py | 4 +- .../image_processing_efficientloftr.py | 3 +- .../models/ijepa/configuration_ijepa.py | 4 +- .../kosmos2_5/image_processing_kosmos2_5.py | 7 +- .../image_processing_kosmos2_5_fast.py | 3 - .../lightglue/image_processing_lightglue.py | 3 +- .../models/longt5/modeling_longt5.py | 10 -- .../models/luke/tokenization_luke.py | 2 +- .../models/mluke/tokenization_mluke.py | 2 +- .../mobilenet_v2/modeling_mobilenet_v2.py | 4 +- .../models/mobilevit/modeling_mobilevit.py | 4 +- .../mobilevitv2/modeling_mobilevitv2.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 10 -- .../pix2struct/image_processing_pix2struct.py | 8 +- .../models/pix2struct/modeling_pix2struct.py | 2 - .../models/pop2piano/modeling_pop2piano.py | 9 - .../superglue/image_processing_superglue.py | 3 +- .../superpoint/image_processing_superpoint.py | 3 +- .../modeling_switch_transformers.py | 11 -- src/transformers/models/t5/modeling_t5.py | 10 -- src/transformers/models/udop/modeling_udop.py | 9 - .../models/vit/configuration_vit.py | 4 +- .../models/whisper/tokenization_whisper.py | 7 +- .../whisper/tokenization_whisper_fast.py | 7 +- .../tokenization_mistral_common.py | 2 +- src/transformers/trainer_utils.py | 2 - .../utils/dummy_tensorflow_text_objects.py | 9 - tests/repo_utils/test_tests_fetcher.py | 47 +---- tests/sagemaker/conftest.py | 24 +-- .../test_multi_node_data_parallel.py | 7 - tests/sagemaker/test_single_node_gpu.py | 7 - utils/not_doctested.txt | 2 - utils/notification_service.py | 17 -- utils/past_ci_versions.py | 126 ------------- utils/print_env.py | 9 - utils/tests_fetcher.py | 34 ++-- utils/update_metadata.py | 26 +-- 50 files changed, 58 insertions(+), 610 deletions(-) delete mode 100644 docker/transformers-tensorflow-gpu/Dockerfile delete mode 100644 src/transformers/commands/convert.py delete mode 100644 src/transformers/utils/dummy_tensorflow_text_objects.py delete mode 100644 utils/past_ci_versions.py diff --git a/README.md b/README.md index 5d782bcea78e..0717343f9cff 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Explore the [Hub](https://huggingface.com/) today to find a model and use Transf ## Installation -Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.1+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+. +Transformers works with Python 3.9+, and [PyTorch](https://pytorch.org/get-started/locally/) 2.1+. Create and activate a virtual environment with [venv](https://docs.python.org/3/library/venv.html) or [uv](https://docs.astral.sh/uv/), a fast Rust-based Python package and project manager. diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 64cd09b928a2..552e5697e96c 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -26,9 +26,7 @@ RUN git clone https://github.com/huggingface/transformers && cd transformers && # 1. Put several commands in a single `RUN` to avoid image/layer exporting issue. Could be revised in the future. # 2. Regarding `torch` part, We might need to specify proper versions for `torchvision` and `torchaudio`. # Currently, let's not bother to specify their versions explicitly (so installed with their latest release versions). -RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] && [ ${#PYTORCH} -gt 0 -a "$PYTORCH" != "pre" ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; echo "export VERSION='$VERSION'" >> ~/.profile && echo torch=$VERSION && [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA && python3 -m pip uninstall -y tensorflow tensorflow_text tensorflow_probability - -RUN python3 -m pip uninstall -y flax jax +RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] && [ ${#PYTORCH} -gt 0 -a "$PYTORCH" != "pre" ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; echo "export VERSION='$VERSION'" >> ~/.profile && echo torch=$VERSION && [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA RUN python3 -m pip install --no-cache-dir -U timm diff --git a/docker/transformers-gpu/Dockerfile b/docker/transformers-gpu/Dockerfile index 30de59d8b50a..e78e52df4897 100644 --- a/docker/transformers-gpu/Dockerfile +++ b/docker/transformers-gpu/Dockerfile @@ -15,7 +15,6 @@ RUN apt update && \ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ python3 -m pip install --no-cache-dir \ jupyter \ - tensorflow \ torch RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/kernels@main#egg=kernels diff --git a/docker/transformers-past-gpu/Dockerfile b/docker/transformers-past-gpu/Dockerfile index a872231d0418..34bfbb19cef5 100644 --- a/docker/transformers-past-gpu/Dockerfile +++ b/docker/transformers-past-gpu/Dockerfile @@ -27,7 +27,7 @@ ARG VERSION RUN [ "$VERSION" != "1.10" ] && python3 -m pip install -U setuptools || python3 -m pip install -U "setuptools<=59.5" # Remove all frameworks -RUN python3 -m pip uninstall -y torch torchvision torchaudio tensorflow jax flax +RUN python3 -m pip uninstall -y torch torchvision torchaudio # Get the libraries and their versions to install, and write installation command to `~/.profile`. RUN python3 ./transformers/utils/past_ci_versions.py --framework $FRAMEWORK --version $VERSION diff --git a/docker/transformers-pytorch-amd-gpu/Dockerfile b/docker/transformers-pytorch-amd-gpu/Dockerfile index 37542ffb8943..4191021d5bf2 100644 --- a/docker/transformers-pytorch-amd-gpu/Dockerfile +++ b/docker/transformers-pytorch-amd-gpu/Dockerfile @@ -23,9 +23,6 @@ RUN git clone https://github.com/huggingface/transformers && cd transformers && # Install transformers RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch,testing,video,audio] -# Remove tensorflow and flax as they are no longer supported by transformers -RUN python3 -m pip uninstall -y tensorflow flax - # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. RUN cd transformers && python3 setup.py develop diff --git a/docker/transformers-pytorch-gpu/Dockerfile b/docker/transformers-pytorch-gpu/Dockerfile index 5909ac436525..96fdba4b8d2d 100644 --- a/docker/transformers-pytorch-gpu/Dockerfile +++ b/docker/transformers-pytorch-gpu/Dockerfile @@ -25,8 +25,6 @@ RUN [ ${#PYTORCH} -gt 0 ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; RUN [ ${#TORCH_VISION} -gt 0 ] && VERSION='torchvision=='TORCH_VISION'.*' || VERSION='torchvision'; python3 -m pip install --no-cache-dir -U $VERSION --extra-index-url https://download.pytorch.org/whl/$CUDA RUN [ ${#TORCH_AUDIO} -gt 0 ] && VERSION='torchaudio=='TORCH_AUDIO'.*' || VERSION='torchaudio'; python3 -m pip install --no-cache-dir -U $VERSION --extra-index-url https://download.pytorch.org/whl/$CUDA -RUN python3 -m pip uninstall -y tensorflow flax - RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract RUN python3 -m pip install -U "itsdangerous<2.1.0" diff --git a/docker/transformers-tensorflow-gpu/Dockerfile b/docker/transformers-tensorflow-gpu/Dockerfile deleted file mode 100644 index 378491a6c600..000000000000 --- a/docker/transformers-tensorflow-gpu/Dockerfile +++ /dev/null @@ -1,25 +0,0 @@ -FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04 -LABEL maintainer="Hugging Face" - -ARG DEBIAN_FRONTEND=noninteractive - -RUN apt update -RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg -RUN python3 -m pip install --no-cache-dir --upgrade pip - -ARG REF=main -RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF -RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-tensorflow,testing] - -# If set to nothing, will install the latest version -ARG TENSORFLOW='2.13' - -RUN [ ${#TENSORFLOW} -gt 0 ] && VERSION='tensorflow=='$TENSORFLOW'.*' || VERSION='tensorflow'; python3 -m pip install --no-cache-dir -U $VERSION -RUN python3 -m pip uninstall -y torch flax -RUN python3 -m pip install -U "itsdangerous<2.1.0" - -RUN python3 -m pip install --no-cache-dir -U "tensorflow_probability<0.22" - -# When installing in editable mode, `transformers` is not recognized as a package. -# this line must be added in order for python to be aware of transformers. -RUN cd transformers && python3 setup.py develop diff --git a/src/transformers/commands/convert.py b/src/transformers/commands/convert.py deleted file mode 100644 index 220d1d44b1aa..000000000000 --- a/src/transformers/commands/convert.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argparse import ArgumentParser, Namespace - -from ..utils import logging -from . import BaseTransformersCLICommand - - -def convert_command_factory(args: Namespace): - """ - Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint. - - Returns: ServeCommand - """ - return ConvertCommand( - args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name - ) - - -IMPORT_ERROR_MESSAGE = """ -transformers can only be used from the commandline to convert TensorFlow models in PyTorch, In that case, it requires -TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions. -""" - - -class ConvertCommand(BaseTransformersCLICommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - """ - Register this command to argparse so it's available for the transformer-cli - - Args: - parser: Root parser to register command-specific arguments - """ - train_parser = parser.add_parser( - "convert", - help="CLI tool to run convert model from original author checkpoints to Transformers PyTorch checkpoints.", - ) - train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.") - train_parser.add_argument( - "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder." - ) - train_parser.add_argument( - "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output." - ) - train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.") - train_parser.add_argument( - "--finetuning_task_name", - type=str, - default=None, - help="Optional fine-tuning task name if the TF model was a finetuned model.", - ) - train_parser.set_defaults(func=convert_command_factory) - - def __init__( - self, - model_type: str, - tf_checkpoint: str, - pytorch_dump_output: str, - config: str, - finetuning_task_name: str, - *args, - ): - self._logger = logging.get_logger("transformers/converting") - - self._logger.info(f"Loading model {model_type}") - self._model_type = model_type - self._tf_checkpoint = tf_checkpoint - self._pytorch_dump_output = pytorch_dump_output - self._config = config - self._finetuning_task_name = finetuning_task_name - - def run(self): - if self._model_type == "albert": - try: - from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import ( - convert_tf_checkpoint_to_pytorch, - ) - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "bert": - try: - from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import ( - convert_tf_checkpoint_to_pytorch, - ) - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "funnel": - try: - from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import ( - convert_tf_checkpoint_to_pytorch, - ) - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "t5": - try: - from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "gpt": - from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import ( - convert_openai_checkpoint_to_pytorch, - ) - - convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "gpt2": - try: - from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import ( - convert_gpt2_checkpoint_to_pytorch, - ) - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - elif self._model_type == "xlnet": - try: - from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import ( - convert_xlnet_checkpoint_to_pytorch, - ) - except ImportError: - raise ImportError(IMPORT_ERROR_MESSAGE) - - convert_xlnet_checkpoint_to_pytorch( - self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name - ) - elif self._model_type == "xlm": - from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import ( - convert_xlm_checkpoint_to_pytorch, - ) - - convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) - elif self._model_type == "lxmert": - from ..models.lxmert.convert_lxmert_original_tf_checkpoint_to_pytorch import ( - convert_lxmert_checkpoint_to_pytorch, - ) - - convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) - elif self._model_type == "rembert": - from ..models.rembert.convert_rembert_tf_checkpoint_to_pytorch import ( - convert_rembert_tf_checkpoint_to_pytorch, - ) - - convert_rembert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) - else: - raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, t5, xlnet, xlm, lxmert]") diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py index c3db333ce594..63be55b558f9 100644 --- a/src/transformers/data/processors/utils.py +++ b/src/transformers/data/processors/utils.py @@ -82,7 +82,7 @@ class DataProcessor: def get_example_from_tensor_dict(self, tensor_dict): """ - Gets an example from a dict with tensorflow tensors. + Gets an example from a dict. Args: tensor_dict: Keys and values should match the corresponding Glue diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index b5c40ca44f1b..f4a0fe30441d 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -165,7 +165,7 @@ def pad( else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." + "Should be one of a python, numpy, or pytorch object." ) for key, value in processed_features.items(): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4be12568ba43..a116a30b6056 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1615,8 +1615,6 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: if encoder_attention_mask.dim() == 2: encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow - # /transformer/transformer_layers.py#L270 # encoder_extended_attention_mask = (encoder_extended_attention_mask == # encoder_extended_attention_mask.transpose(-1, -2)) encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index 7a321ebe293e..8909fe5aff66 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -72,9 +72,7 @@ class DeiTConfig(PretrainedConfig): pooler_output_size (`int`, *optional*): Dimensionality of the pooler layer. If None, defaults to `hidden_size`. pooler_act (`str`, *optional*, defaults to `"tanh"`): - The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and - Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are - supported for Tensorflow. + The activation function to be used by the pooler. Example: diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index 0e2d27f03bac..eb5b030bbf82 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -706,22 +706,15 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseModel): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0) module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None: module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, (GPTSanJapaneseModel, GPTSanJapaneseForConditionalGeneration)): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -738,8 +731,6 @@ def _init_weights(self, module): module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) elif isinstance(module, GPTSanJapaneseSparseMLP): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_model n_heads = self.config.num_heads diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index 70e46f232022..d0263630b075 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -122,9 +122,7 @@ class DPTConfig(PretrainedConfig): pooler_output_size (`int`, *optional*): Dimensionality of the pooler layer. If None, defaults to `hidden_size`. pooler_act (`str`, *optional*, defaults to `"tanh"`): - The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and - Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are - supported for Tensorflow. + The activation function to be used by the pooler. Example: diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py index 58ce0e96f5b8..2146ea8b39eb 100644 --- a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py +++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py @@ -70,8 +70,7 @@ def convert_to_grayscale( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> ImageInput: """ - Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch - and tensorflow grayscale conversion + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each channel, because of an issue that is discussed in : diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py index 5f528adad0d5..084a7d8f3d94 100644 --- a/src/transformers/models/ijepa/configuration_ijepa.py +++ b/src/transformers/models/ijepa/configuration_ijepa.py @@ -59,9 +59,7 @@ class IJepaConfig(PretrainedConfig): pooler_output_size (`int`, *optional*): Dimensionality of the pooler layer. If None, defaults to `hidden_size`. pooler_act (`str`, *optional*, defaults to `"tanh"`): - The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and - Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are - supported for Tensorflow. + The activation function to be used by the pooler. Example: diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py index d603575ef32d..18c087a4b368 100644 --- a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py @@ -209,9 +209,6 @@ def normalize( """ Normalize an image. image = (image - image_mean) / image_std. - The image std is to mimic the tensorflow implementation of the `per_image_standardization`: - https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization - Args: image (`np.ndarray`): Image to normalize. @@ -253,9 +250,7 @@ def preprocess( """ Preprocess an image or batch of images. The processor first computes the maximum possible number of aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the - image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the - images are standardized following the tensorflow implementation of `per_image_standardization` - (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). + image with zeros to make the image respect the constraint of `max_patches`. Args: diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py index c6d8b1b1edf5..c539288d9913 100644 --- a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py @@ -105,9 +105,6 @@ def normalize( """ Normalize an image. image = (image - image_mean) / image_std. - The image std is to mimic the tensorflow implementation of the `per_image_standardization`: - https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization - Args: image (`torch.Tensor`): Image to normalize. diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py index 400475b76c77..ce925ea173dd 100644 --- a/src/transformers/models/lightglue/image_processing_lightglue.py +++ b/src/transformers/models/lightglue/image_processing_lightglue.py @@ -73,8 +73,7 @@ def convert_to_grayscale( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> ImageInput: """ - Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch - and tensorflow grayscale conversion + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each channel, because of an issue that is discussed in : diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 4e84a1550349..70eec28d89f4 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1276,15 +1276,10 @@ def _init_weights(self, module): if isinstance(module, LongT5LayerNorm): module.weight.data.fill_(factor * 1.0) elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, LongT5DenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -1302,8 +1297,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -2072,8 +2065,6 @@ def forward( sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) @@ -2084,7 +2075,6 @@ def forward( labels = labels.to(lm_logits.device) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 6838b9c5cb75..fc95ed11b079 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -1486,7 +1486,7 @@ def pad( else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." + "Should be one of a python, numpy, or pytorch object." ) for key, value in encoded_inputs.items(): diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 5c2e8c806da8..3d7a210d816a 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -1324,7 +1324,7 @@ def pad( else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." + "Should be one of a python, numpy, or pytorch object." ) for key, value in encoded_inputs.items(): diff --git a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py index d5c5f2e7fdb9..2d30da8f756d 100755 --- a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py @@ -36,9 +36,7 @@ def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: """ - Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the - original TensorFlow repo. It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + Ensure that all layers have a channel count that is divisible by `divisor`. """ if min_value is None: min_value = divisor diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 415c33a7cb85..b74987040fc7 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -42,9 +42,7 @@ def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: """ - Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the - original TensorFlow repo. It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + Ensure that all layers have a channel count that is divisible by `divisor`. """ if min_value is None: min_value = divisor diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 4e0e972a648a..0e9e616c52cd 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -41,9 +41,7 @@ # Copied from transformers.models.mobilevit.modeling_mobilevit.make_divisible def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: """ - Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the - original TensorFlow repo. It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + Ensure that all layers have a channel count that is divisible by `divisor`. """ if min_value is None: min_value = divisor diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 69124c13a115..451db734a3a5 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -680,8 +680,6 @@ def _init_weights(self, module): module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering), ): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) @@ -700,9 +698,6 @@ def _init_weights(self, module): if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: module.out_proj.bias.data.zero_() elif isinstance(module, MT5DenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -720,8 +715,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, MT5Attention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -1701,8 +1694,6 @@ def forward( sequence_output = sequence_output.to(self.lm_head.weight.device) if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) @@ -1713,7 +1704,6 @@ def forward( # move labels to correct device to enable PP labels = labels.to(lm_logits.device) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs diff --git a/src/transformers/models/pix2struct/image_processing_pix2struct.py b/src/transformers/models/pix2struct/image_processing_pix2struct.py index 94ae65777692..f1cc2ba2068b 100644 --- a/src/transformers/models/pix2struct/image_processing_pix2struct.py +++ b/src/transformers/models/pix2struct/image_processing_pix2struct.py @@ -316,9 +316,6 @@ def normalize( """ Normalize an image. image = (image - image_mean) / image_std. - The image std is to mimic the tensorflow implementation of the `per_image_standardization`: - https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization - Args: image (`np.ndarray`): Image to normalize. @@ -361,10 +358,7 @@ def preprocess( """ Preprocess an image or batch of images. The processor first computes the maximum possible number of aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the - image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the - images are standardized following the tensorflow implementation of `per_image_standardization` - (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). - + image with zeros to make the image respect the constraint of `max_patches`. Args: images (`ImageInput`): diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 463fec98256f..c79e937fc28a 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -387,8 +387,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, Pix2StructTextAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 hidden_size = ( self.config.text_config.hidden_size if isinstance(self.config, Pix2StructConfig) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index ea6d3a5eea9e..74d9ebc97dc2 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -595,15 +595,10 @@ def _init_weights(self, module): elif isinstance(module, Pop2PianoConcatEmbeddingToMel): module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoForConditionalGeneration): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -621,8 +616,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, Pop2PianoAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -1207,8 +1200,6 @@ def forward( sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py index bde3355d78ed..f9192ac82df1 100644 --- a/src/transformers/models/superglue/image_processing_superglue.py +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -73,8 +73,7 @@ def convert_to_grayscale( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> ImageInput: """ - Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch - and tensorflow grayscale conversion + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each channel, because of an issue that is discussed in : diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index 4c895b035feb..f5da3335078e 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -64,8 +64,7 @@ def convert_to_grayscale( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> ImageInput: """ - Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch - and tensorflow grayscale conversion + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each channel, because of an issue that is discussed in : diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0c2844250968..a2e3f8fb4c10 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -794,15 +794,10 @@ def _init_weights(self, module): module, (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), ): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, SwitchTransformersDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -810,8 +805,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, SwitchTransformersAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -822,8 +815,6 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, SwitchTransformersSparseMLP): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -1630,8 +1621,6 @@ def forward( sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index d5f09575e66d..7cd5e6394afd 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -680,8 +680,6 @@ def _init_weights(self, module): module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering), ): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) @@ -700,9 +698,6 @@ def _init_weights(self, module): if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: module.out_proj.bias.data.zero_() elif isinstance(module, T5DenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -720,8 +715,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, T5Attention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -1664,8 +1657,6 @@ def forward( sequence_output = sequence_output.to(self.lm_head.weight.device) if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) @@ -1676,7 +1667,6 @@ def forward( # move labels to correct device to enable PP labels = labels.to(lm_logits.device) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 668ec6bfec3b..deacc075c72b 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -280,16 +280,11 @@ def _init_weights(self, module): d_model = self.config.d_model module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, UdopModel): - # Mesh TensorFlow embeddings initialization - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, UdopForConditionalGeneration): if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, UdopDenseActDense): - # Mesh TensorFlow FF initialization - # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 - # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() @@ -307,8 +302,6 @@ def _init_weights(self, module): if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, UdopAttention): - # Mesh TensorFlow attention initialization to avoid scaling before softmax - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads @@ -1852,8 +1845,6 @@ def forward( sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.config.d_model**-0.5) lm_logits = self.lm_head(sequence_output) diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py index ead272d0086d..7d69cdf51946 100644 --- a/src/transformers/models/vit/configuration_vit.py +++ b/src/transformers/models/vit/configuration_vit.py @@ -71,9 +71,7 @@ class ViTConfig(PretrainedConfig): pooler_output_size (`int`, *optional*): Dimensionality of the pooler layer. If None, defaults to `hidden_size`. pooler_act (`str`, *optional*, defaults to `"tanh"`): - The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and - Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are - supported for Tensorflow. + The activation function to be used by the pooler. Example: diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 4147f14d86bd..5e0d04acef1f 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -898,12 +898,7 @@ def _strip_prompt(self, token_ids: list[int], prompt_token_id: int, decoder_star def _convert_to_list(token_ids): # convert type to ndarray if necessary if hasattr(token_ids, "numpy"): - if "torch" in str(type(token_ids)): - token_ids = token_ids.cpu().numpy() - elif "tensorflow" in str(type(token_ids)): - token_ids = token_ids.numpy() - elif "jaxlib" in str(type(token_ids)): - token_ids = token_ids.tolist() + token_ids = token_ids.cpu().numpy() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 07f4fdfcb002..235b356edab4 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -624,12 +624,7 @@ def _strip_prompt(self, token_ids: list[int], prompt_token_id: int, decoder_star def _convert_to_list(token_ids): # convert type to ndarray if necessary if hasattr(token_ids, "numpy"): - if "torch" in str(type(token_ids)): - token_ids = token_ids.cpu().numpy() - elif "tensorflow" in str(type(token_ids)): - token_ids = token_ids.numpy() - elif "jaxlib" in str(type(token_ids)): - token_ids = token_ids.tolist() + token_ids = token_ids.cpu().numpy() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/src/transformers/tokenization_mistral_common.py b/src/transformers/tokenization_mistral_common.py index a362a7c8b066..c81ed6831552 100644 --- a/src/transformers/tokenization_mistral_common.py +++ b/src/transformers/tokenization_mistral_common.py @@ -1239,7 +1239,7 @@ def pad( else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." + "Should be one of a python, numpy, or pytorch object." ) for key, value in encoded_inputs.items(): diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index a2d84b024057..055580c96177 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -452,8 +452,6 @@ class TrainerMemoryTracker: self._memory_tracker.stop_and_update_metrics(metrics) ``` - At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`. - To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`]. """ diff --git a/src/transformers/utils/dummy_tensorflow_text_objects.py b/src/transformers/utils/dummy_tensorflow_text_objects.py deleted file mode 100644 index 70c7ad5cbf40..000000000000 --- a/src/transformers/utils/dummy_tensorflow_text_objects.py +++ /dev/null @@ -1,9 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class TFBertTokenizer(metaclass=DummyObject): - _backends = ["tensorflow_text"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tensorflow_text"]) diff --git a/tests/repo_utils/test_tests_fetcher.py b/tests/repo_utils/test_tests_fetcher.py index 727cb2affa08..0a7917e6033c 100644 --- a/tests/repo_utils/test_tests_fetcher.py +++ b/tests/repo_utils/test_tests_fetcher.py @@ -151,15 +151,14 @@ def create_tmp_repo(tmp_dir, models=None): example_dir = tmp_dir / "examples" example_dir.mkdir(exist_ok=True) - for framework in ["flax", "pytorch", "tensorflow"]: - framework_dir = example_dir / framework - framework_dir.mkdir(exist_ok=True) - with open(framework_dir / f"test_{framework}_examples.py", "w") as f: - f.write("""test_args = "run_glue.py"\n""") - glue_dir = framework_dir / "text-classification" - glue_dir.mkdir(exist_ok=True) - with open(glue_dir / "run_glue.py", "w") as f: - f.write("from transformers import BertModel\n\ncode") + framework_dir = example_dir / "pytorch" + framework_dir.mkdir(exist_ok=True) + with open(framework_dir / "test_pytorch_examples.py", "w") as f: + f.write("""test_args = "run_glue.py"\n""") + glue_dir = framework_dir / "text-classification" + glue_dir.mkdir(exist_ok=True) + with open(glue_dir / "run_glue.py", "w") as f: + f.write("from transformers import BertModel\n\ncode") repo.index.add(["examples", "src", "tests"]) repo.index.commit("Initial commit") @@ -525,27 +524,15 @@ def test_init_test_examples_dependencies(self): create_tmp_repo(tmp_folder) expected_example_deps = { - "examples/flax/test_flax_examples.py": [ - "examples/flax/text-classification/run_glue.py", - "examples/flax/test_flax_examples.py", - ], "examples/pytorch/test_pytorch_examples.py": [ "examples/pytorch/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", ], - "examples/tensorflow/test_tensorflow_examples.py": [ - "examples/tensorflow/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - ], } expected_examples = { - "examples/flax/test_flax_examples.py", - "examples/flax/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", "examples/pytorch/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - "examples/tensorflow/text-classification/run_glue.py", } with patch_transformer_repo_path(tmp_folder): @@ -565,12 +552,8 @@ def test_create_reverse_dependency_map(self): "src/transformers/__init__.py", "src/transformers/models/bert/__init__.py", "tests/models/bert/test_modeling_bert.py", - "examples/flax/test_flax_examples.py", - "examples/flax/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", "examples/pytorch/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - "examples/tensorflow/text-classification/run_glue.py", } assert set(reverse_map["src/transformers/models/bert/modeling_bert.py"]) == expected_bert_deps @@ -586,12 +569,8 @@ def test_create_reverse_dependency_map(self): "src/transformers/modeling_utils.py", "tests/test_modeling_common.py", "tests/models/bert/test_modeling_bert.py", - "examples/flax/test_flax_examples.py", - "examples/flax/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", "examples/pytorch/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - "examples/tensorflow/text-classification/run_glue.py", } assert set(reverse_map["src/transformers/__init__.py"]) == expected_init_deps @@ -600,12 +579,8 @@ def test_create_reverse_dependency_map(self): "src/transformers/models/bert/configuration_bert.py", "src/transformers/models/bert/modeling_bert.py", "tests/models/bert/test_modeling_bert.py", - "examples/flax/test_flax_examples.py", - "examples/flax/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", "examples/pytorch/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - "examples/tensorflow/text-classification/run_glue.py", } assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps @@ -620,12 +595,8 @@ def test_create_reverse_dependency_map(self): "src/transformers/models/bert/configuration_bert.py", "src/transformers/models/bert/modeling_bert.py", "tests/models/bert/test_modeling_bert.py", - "examples/flax/test_flax_examples.py", - "examples/flax/text-classification/run_glue.py", "examples/pytorch/test_pytorch_examples.py", "examples/pytorch/text-classification/run_glue.py", - "examples/tensorflow/test_tensorflow_examples.py", - "examples/tensorflow/text-classification/run_glue.py", } assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps @@ -639,9 +610,7 @@ def test_infer_tests_to_run(self): commit_changes("src/transformers/models/bert/modeling_bert.py", BERT_MODEL_FILE_NEW_CODE, repo) example_tests = { - "examples/flax/test_flax_examples.py", "examples/pytorch/test_pytorch_examples.py", - "examples/tensorflow/test_tensorflow_examples.py", } with patch_transformer_repo_path(tmp_folder): diff --git a/tests/sagemaker/conftest.py b/tests/sagemaker/conftest.py index 89b89966d542..4dbfac3588ee 100644 --- a/tests/sagemaker/conftest.py +++ b/tests/sagemaker/conftest.py @@ -12,7 +12,7 @@ @dataclass class SageMakerTestEnvironment: - framework: str + framework: str = "pytorch" role = "arn:aws:iam::558105141721:role/sagemaker_execution_role" hyperparameters = { "task_name": "mnli", @@ -30,18 +30,11 @@ class SageMakerTestEnvironment: @property def metric_definitions(self) -> str: - if self.framework == "pytorch": - return [ - {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"}, - {"Name": "eval_accuracy", "Regex": r"eval_accuracy.*=\D*(.*?)$"}, - {"Name": "eval_loss", "Regex": r"eval_loss.*=\D*(.*?)$"}, - ] - else: - return [ - {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"}, - {"Name": "eval_accuracy", "Regex": r"loss.*=\D*(.*?)]?$"}, - {"Name": "eval_loss", "Regex": r"sparse_categorical_accuracy.*=\D*(.*?)]?$"}, - ] + return [ + {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"}, + {"Name": "eval_accuracy", "Regex": r"eval_accuracy.*=\D*(.*?)$"}, + {"Name": "eval_loss", "Regex": r"eval_loss.*=\D*(.*?)$"}, + ] @property def base_job_name(self) -> str: @@ -53,10 +46,7 @@ def test_path(self) -> str: @property def image_uri(self) -> str: - if self.framework == "pytorch": - return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04" - else: - return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.6.1-gpu-py37-cu110-ubuntu18.04" + return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04" @pytest.fixture(scope="class") diff --git a/tests/sagemaker/test_multi_node_data_parallel.py b/tests/sagemaker/test_multi_node_data_parallel.py index 2ea029a28551..602a90e6d8e8 100644 --- a/tests/sagemaker/test_multi_node_data_parallel.py +++ b/tests/sagemaker/test_multi_node_data_parallel.py @@ -36,13 +36,6 @@ "instance_type": "ml.p3.16xlarge", "results": {"train_runtime": 600, "eval_accuracy": 0.7, "eval_loss": 0.6}, }, - { - "framework": "tensorflow", - "script": "run_tf_dist.py", - "model_name_or_path": "distilbert/distilbert-base-cased", - "instance_type": "ml.p3.16xlarge", - "results": {"train_runtime": 600, "eval_accuracy": 0.6, "eval_loss": 0.7}, - }, ] ) class MultiNodeTest(unittest.TestCase): diff --git a/tests/sagemaker/test_single_node_gpu.py b/tests/sagemaker/test_single_node_gpu.py index 53d966bd1e85..c1902797391d 100644 --- a/tests/sagemaker/test_single_node_gpu.py +++ b/tests/sagemaker/test_single_node_gpu.py @@ -29,13 +29,6 @@ "instance_type": "ml.g4dn.xlarge", "results": {"train_runtime": 650, "eval_accuracy": 0.6, "eval_loss": 0.9}, }, - { - "framework": "tensorflow", - "script": "run_tf.py", - "model_name_or_path": "distilbert/distilbert-base-cased", - "instance_type": "ml.g4dn.xlarge", - "results": {"train_runtime": 600, "eval_accuracy": 0.3, "eval_loss": 0.9}, - }, ] ) class SingleNodeTest(unittest.TestCase): diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index ed6ad85bf8fd..25bc2fde3663 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -313,7 +313,6 @@ docs/source/en/troubleshooting.md src/transformers/activations.py src/transformers/audio_utils.py src/transformers/commands/add_new_model_like.py -src/transformers/commands/convert.py src/transformers/commands/download.py src/transformers/commands/env.py src/transformers/commands/run.py @@ -807,7 +806,6 @@ src/transformers/utils/dummy_pt_objects.py src/transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py src/transformers/utils/dummy_sentencepiece_objects.py src/transformers/utils/dummy_speech_objects.py -src/transformers/utils/dummy_tensorflow_text_objects.py src/transformers/utils/dummy_tokenizers_objects.py src/transformers/utils/dummy_vision_objects.py src/transformers/utils/fx.py diff --git a/utils/notification_service.py b/utils/notification_service.py index 410d3ba78507..ca2085c74b25 100644 --- a/utils/notification_service.py +++ b/utils/notification_service.py @@ -37,7 +37,6 @@ "run_models_gpu": "Models", "run_trainer_and_fsdp_gpu": "Trainer & FSDP", "run_pipelines_torch_gpu": "PyTorch pipelines", - "run_pipelines_tf_gpu": "TensorFlow pipelines", "run_examples_gpu": "Examples directory", "run_torch_cuda_extensions_gpu": "DeepSpeed", "run_quantization_torch_gpu": "Quantization", @@ -48,7 +47,6 @@ "Models": "model", "Trainer & FSDP": "trainer_and_fsdp", "PyTorch pipelines": "torch_pipeline", - "TensorFlow pipelines": "tf_pipeline", "Examples directory": "example", "DeepSpeed": "deepspeed", "Quantization": "quantization", @@ -390,12 +388,10 @@ def per_model_sum(model_category_dict): # Model job has a special form for reporting if job_name == "run_models_gpu": pytorch_specific_failures = dict_failed.pop("PyTorch") - tensorflow_specific_failures = dict_failed.pop("TensorFlow") other_failures = dicts_to_sum(dict_failed.values()) failures[k] = { "PyTorch": pytorch_specific_failures, - "TensorFlow": tensorflow_specific_failures, "other": other_failures, } @@ -429,8 +425,6 @@ def per_model_sum(model_category_dict): device_report_values = [ value["PyTorch"]["single"], value["PyTorch"]["multi"], - value["TensorFlow"]["single"], - value["TensorFlow"]["multi"], sum(value["other"].values()), ] @@ -1168,8 +1162,6 @@ def pop_default(l: list[Any], i: int, default: Any) -> Any: test_categories = [ "PyTorch", - "TensorFlow", - "Flax", "Tokenizers", "Pipelines", "Trainer", @@ -1249,12 +1241,6 @@ def pop_default(l: list[Any], i: int, default: Any) -> Any: if re.search("tests/quantization", line): matrix_job_results[matrix_name]["failed"]["Quantization"][artifact_gpu] += 1 - elif re.search("test_modeling_tf_", line): - matrix_job_results[matrix_name]["failed"]["TensorFlow"][artifact_gpu] += 1 - - elif re.search("test_modeling_flax_", line): - matrix_job_results[matrix_name]["failed"]["Flax"][artifact_gpu] += 1 - elif re.search("test_modeling", line): matrix_job_results[matrix_name]["failed"]["PyTorch"][artifact_gpu] += 1 @@ -1280,7 +1266,6 @@ def pop_default(l: list[Any], i: int, default: Any) -> Any: # Additional runs additional_files = { "PyTorch pipelines": "run_pipelines_torch_gpu_test_reports", - "TensorFlow pipelines": "run_pipelines_tf_gpu_test_reports", "Examples directory": "run_examples_gpu_test_reports", "DeepSpeed": "run_torch_cuda_extensions_gpu_test_reports", } @@ -1288,9 +1273,7 @@ def pop_default(l: list[Any], i: int, default: Any) -> Any: if ci_event in ["push", "Nightly CI"] or ci_event.startswith("Past CI"): del additional_files["Examples directory"] del additional_files["PyTorch pipelines"] - del additional_files["TensorFlow pipelines"] elif ci_event.startswith("Scheduled CI (AMD)"): - del additional_files["TensorFlow pipelines"] del additional_files["DeepSpeed"] elif ci_event.startswith("Push CI (AMD)"): additional_files = {} diff --git a/utils/past_ci_versions.py b/utils/past_ci_versions.py deleted file mode 100644 index 858f7184d707..000000000000 --- a/utils/past_ci_versions.py +++ /dev/null @@ -1,126 +0,0 @@ -import argparse -import os - - -past_versions_testing = { - "pytorch": { - "1.13": { - "torch": "1.13.1", - "torchvision": "0.14.1", - "torchaudio": "0.13.1", - "python": 3.9, - "cuda": "cu116", - "install": ( - "python3 -m pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1" - " --extra-index-url https://download.pytorch.org/whl/cu116" - ), - "base_image": "nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04", - }, - "1.12": { - "torch": "1.12.1", - "torchvision": "0.13.1", - "torchaudio": "0.12.1", - "python": 3.9, - "cuda": "cu113", - "install": ( - "python3 -m pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1" - " --extra-index-url https://download.pytorch.org/whl/cu113" - ), - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "1.11": { - "torch": "1.11.0", - "torchvision": "0.12.0", - "torchaudio": "0.11.0", - "python": 3.9, - "cuda": "cu113", - "install": ( - "python3 -m pip install --no-cache-dir -U torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0" - " --extra-index-url https://download.pytorch.org/whl/cu113" - ), - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "1.10": { - "torch": "1.10.2", - "torchvision": "0.11.3", - "torchaudio": "0.10.2", - "python": 3.9, - "cuda": "cu113", - "install": ( - "python3 -m pip install --no-cache-dir -U torch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2" - " --extra-index-url https://download.pytorch.org/whl/cu113" - ), - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - # torchaudio < 0.10 has no CUDA-enabled binary distributions - "1.9": { - "torch": "1.9.1", - "torchvision": "0.10.1", - "torchaudio": "0.9.1", - "python": 3.9, - "cuda": "cu111", - "install": ( - "python3 -m pip install --no-cache-dir -U torch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1" - " --extra-index-url https://download.pytorch.org/whl/cu111" - ), - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - }, - "tensorflow": { - "2.11": { - "tensorflow": "2.11.1", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.11.1", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.10": { - "tensorflow": "2.10.1", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.10.1", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.9": { - "tensorflow": "2.9.3", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.9.3", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.8": { - "tensorflow": "2.8.2", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.8.2", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.7": { - "tensorflow": "2.7.3", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.7.3", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.6": { - "tensorflow": "2.6.5", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.6.5", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - "2.5": { - "tensorflow": "2.5.3", - "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.5.3", - "base_image": "nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04", - }, - }, -} - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("Choose the framework and version to install") - parser.add_argument( - "--framework", help="The framework to install. Should be `torch` or `tensorflow`", type=str, required=True - ) - parser.add_argument("--version", help="The version of the framework to install.", type=str, required=True) - args = parser.parse_args() - - info = past_versions_testing[args.framework][args.version] - - os.system(f"echo \"export INSTALL_CMD='{info['install']}'\" >> ~/.profile") - print(f"echo \"export INSTALL_CMD='{info['install']}'\" >> ~/.profile") - - cuda = "" - if args.framework == "pytorch": - cuda = info["cuda"] - os.system(f"echo \"export CUDA='{cuda}'\" >> ~/.profile") - print(f"echo \"export CUDA='{cuda}'\" >> ~/.profile") diff --git a/utils/print_env.py b/utils/print_env.py index d693d7e83b6f..839bebc88c0a 100644 --- a/utils/print_env.py +++ b/utils/print_env.py @@ -64,15 +64,6 @@ except ImportError: print("DeepSpeed version:", None) -try: - import tensorflow as tf - - print("TensorFlow version:", tf.__version__) - print("TF GPUs available:", bool(tf.config.list_physical_devices("GPU"))) - print("Number of TF GPUs available:", len(tf.config.list_physical_devices("GPU"))) -except ImportError: - print("TensorFlow version:", None) - try: import torchcodec diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index d200fc83b742..c1ea0a29cc37 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -817,26 +817,22 @@ def init_test_examples_dependencies() -> tuple[dict[str, list[str]], list[str]]: """ test_example_deps = {} all_examples = [] - for framework in ["flax", "pytorch", "tensorflow"]: - test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py")) - all_examples.extend(test_files) - # Remove the files at the root of examples/framework since they are not proper examples (they are either utils - # or example test files). - examples = [ - f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework + + test_files = list((PATH_TO_EXAMPLES / "pytorch").glob("test_*.py")) + all_examples.extend(test_files) + # Remove the files at the root of examples/framework since they are not proper examples (they are either utils + # or example test files). + examples = [f for f in (PATH_TO_EXAMPLES / "pytorch").glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / "pytorch"] + all_examples.extend(examples) + for test_file in test_files: + with open(test_file, "r", encoding="utf-8") as f: + content = f.read() + # Map all examples to the test files found in examples/pytorch. + test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [ + str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content ] - all_examples.extend(examples) - for test_file in test_files: - with open(test_file, "r", encoding="utf-8") as f: - content = f.read() - # Map all examples to the test files found in examples/framework. - test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [ - str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content - ] - # Also map the test files to themselves. - test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append( - str(test_file.relative_to(PATH_TO_REPO)) - ) + # Also map the test files to themselves. + test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append(str(test_file.relative_to(PATH_TO_REPO))) return test_example_deps, all_examples diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 9f04300382e4..0e80ca2b866d 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -51,10 +51,7 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH) -# Regexes that match TF/Flax/PT model names. -_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") -_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") -# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes. +# Regexes that match model names _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)") @@ -157,21 +154,12 @@ def get_frameworks_table() -> pd.DataFrame: config.replace("Config", ""): model_type for model_type, config in config_mapping_names.items() } - # Dictionaries flagging if each model prefix has a backend in PT/TF/Flax. pt_models = collections.defaultdict(bool) - tf_models = collections.defaultdict(bool) - flax_models = collections.defaultdict(bool) # Let's lookup through all transformers object (once) and find if models are supported by a given backend. for attr_name in dir(transformers_module): lookup_dict = None - if _re_tf_models.match(attr_name) is not None: - lookup_dict = tf_models - attr_name = _re_tf_models.match(attr_name).groups()[0] - elif _re_flax_models.match(attr_name) is not None: - lookup_dict = flax_models - attr_name = _re_flax_models.match(attr_name).groups()[0] - elif _re_pt_models.match(attr_name) is not None: + if _re_pt_models.match(attr_name) is not None: lookup_dict = pt_models attr_name = _re_pt_models.match(attr_name).groups()[0] @@ -183,14 +171,12 @@ def get_frameworks_table() -> pd.DataFrame: # Try again after removing the last word in the name attr_name = "".join(camel_case_split(attr_name)[:-1]) - all_models = set(list(pt_models.keys()) + list(tf_models.keys()) + list(flax_models.keys())) + all_models = set(pt_models.keys()) all_models = list(all_models) all_models.sort() data = {"model_type": all_models} data["pytorch"] = [pt_models[t] for t in all_models] - data["tensorflow"] = [tf_models[t] for t in all_models] - data["flax"] = [flax_models[t] for t in all_models] # Now let's find the right processing class for each model. In order we check if there is a Processor, then a # Tokenizer, then a FeatureExtractor, then an ImageProcessor @@ -227,12 +213,10 @@ def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> d """ auto_modules = [ transformers_module.models.auto.modeling_auto, - transformers_module.models.auto.modeling_tf_auto, - transformers_module.models.auto.modeling_flax_auto, ] for pipeline_tag, model_mapping, auto_class in PIPELINE_TAGS_AND_AUTO_MODELS: - model_mappings = [model_mapping, f"TF_{model_mapping}", f"FLAX_{model_mapping}"] - auto_classes = [auto_class, f"TF_{auto_class}", f"Flax_{auto_class}"] + model_mappings = [model_mapping] + auto_classes = [auto_class] # Loop through all three frameworks for module, cls, mapping in zip(auto_modules, auto_classes, model_mappings): # The type of pipeline may not exist in this framework From f7239b8ad696f7a3ab51bd0a0da343a5e57f1f00 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 20:29:15 +0200 Subject: [PATCH 11/35] still grinding --- src/transformers/commands/transformers_cli.py | 2 - src/transformers/modeling_utils.py | 37 ++++------------- .../models/eomt/image_processing_eomt.py | 2 +- src/transformers/onnx/__main__.py | 16 +------- .../pipelines/automatic_speech_recognition.py | 5 --- .../pipelines/document_question_answering.py | 27 +++++------- .../pipelines/feature_extraction.py | 7 +--- .../pipelines/image_feature_extraction.py | 10 ++--- .../pipelines/image_segmentation.py | 11 ++--- .../pipelines/object_detection.py | 8 +--- src/transformers/pipelines/text_to_audio.py | 3 -- .../pipelines/zero_shot_object_detection.py | 12 ++---- .../tokenization_mistral_common.py | 7 +--- .../fixtures/add_distilbert_like_config.json | 19 --------- .../camembert/test_tokenization_camembert.py | 3 -- .../test_pipelines_text_generation.py | 10 +---- tests/tokenization/test_tokenization_utils.py | 1 - .../import_structure_raw_register.py | 10 ++--- ...import_structure_register_with_comments.py | 10 ++--- tests/utils/test_import_structure.py | 3 +- utils/add_pipeline_model_mapping_to_test.py | 41 ++++--------------- utils/tests_fetcher.py | 15 ++----- 22 files changed, 53 insertions(+), 206 deletions(-) delete mode 100644 tests/fixtures/add_distilbert_like_config.json diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py index 00eaff01a4ef..1a283a1c512c 100644 --- a/src/transformers/commands/transformers_cli.py +++ b/src/transformers/commands/transformers_cli.py @@ -18,7 +18,6 @@ from transformers.commands.add_fast_image_processor import AddFastImageProcessorCommand from transformers.commands.add_new_model_like import AddNewModelLikeCommand from transformers.commands.chat import ChatCommand -from transformers.commands.convert import ConvertCommand from transformers.commands.download import DownloadCommand from transformers.commands.env import EnvironmentCommand from transformers.commands.run import RunCommand @@ -39,7 +38,6 @@ def main(): # Register commands ChatCommand.register_subcommand(commands_parser) - ConvertCommand.register_subcommand(commands_parser) DownloadCommand.register_subcommand(commands_parser) EnvironmentCommand.register_subcommand(commands_parser) RunCommand.register_subcommand(commands_parser) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a116a30b6056..d5555f92b0b2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -503,13 +503,6 @@ def load_state_dict( # Use safetensors if possible if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): with safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - - if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: - raise OSError( - f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " - "you save your model with the `save_pretrained` method." - ) state_dict = {} for k in f.keys(): if map_location == "meta": @@ -4956,34 +4949,18 @@ def from_pretrained( transformers_explicit_filename=transformers_explicit_filename, ) - is_sharded = sharded_metadata is not None is_quantized = hf_quantizer is not None is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None - if ( - is_safetensors_available() - and is_from_file - and not is_sharded - and checkpoint_files[0].endswith(".safetensors") - ): + # Just a helpful message in case we try to load safetensors files coming from old Transformers tf/flax classes + if is_safetensors_available() and is_from_file and checkpoint_files[0].endswith(".safetensors"): with safe_open(checkpoint_files[0], framework="pt") as f: metadata = f.metadata() - - if metadata is None: - # Assume it's a pytorch checkpoint (introduced for timm checkpoints) - pass - elif metadata.get("format") == "pt": - pass - elif metadata.get("format") == "tf": - raise ValueError("The safetensors file found has format `'tf'`, which is incompatible.") - elif metadata.get("format") == "flax": - raise ValueError("The safetensors file found has format `'flax'`, which is incompatible.") - elif metadata.get("format") == "mlx": - # This is a mlx file, we assume weights are compatible with pt - pass - else: - raise ValueError( - f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" + if metadata is not None and metadata.get("format") in ["tf", "flax"]: + logger.warning( + "The safetensors checkpoint found has format `tf` or `flax`. This mean that the keys will very" + "likely not match to the model you are trying to load, and will be newly initialized. If it's the case " + "another warning will be raised later. Consider converting your checkpoint to the correct format." ) if gguf_file: diff --git a/src/transformers/models/eomt/image_processing_eomt.py b/src/transformers/models/eomt/image_processing_eomt.py index 93a440693dee..83bc70521019 100644 --- a/src/transformers/models/eomt/image_processing_eomt.py +++ b/src/transformers/models/eomt/image_processing_eomt.py @@ -557,7 +557,7 @@ def preprocess( Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels denoted with 0 (background) will be replaced with `ignore_index`. return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be `"pt"`, `"tf"`, `"np"`, or `"jax"`. + The type of tensors to return. Can be `"pt"` or `"np"`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): Channel format of the output image. Either `"channels_first"` or `"channels_last"`. input_data_format (`ChannelDimension` or `str`, *optional*): diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index e3dc6dfb78aa..db43126fd3fb 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -53,7 +53,6 @@ def export_with_optimum(args): "optimum.exporters.onnx", f"--model {args.model}", f"--task {args.feature}", - f"--framework {args.framework}" if args.framework is not None else "", f"{args.output}", ] proc = subprocess.Popen(cmd_line, stdout=subprocess.PIPE) @@ -72,9 +71,7 @@ def export_with_transformers(args): args.output.parent.mkdir(parents=True) # Allocate the model - model = FeaturesManager.get_model_from_feature( - args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir - ) + model = FeaturesManager.get_model_from_feature(args.feature, args.model, cache_dir=args.cache_dir) model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) onnx_config = model_onnx_config(model.config) @@ -199,17 +196,6 @@ def main(): parser.add_argument( "--atol", type=float, default=None, help="Absolute difference tolerance when validating the model." ) - parser.add_argument( - "--framework", - type=str, - choices=["pt", "tf"], - default=None, - help=( - "The framework to use for the ONNX export." - " If not provided, will attempt to use the local checkpoint's original framework" - " or what is available in the environment." - ), - ) parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.") parser.add_argument( diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 88071764aafb..35bd35fc6b29 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -168,11 +168,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): - framework (`str`, *optional*): - The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be - installed. If no framework is specified, will default to the one currently installed. If no framework is - specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if - no model is provided. device (Union[`int`, `torch.device`], *optional*): Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the model on the associated CUDA device id. diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py index d45e756b0f3f..e6592a8ab4a3 100644 --- a/src/transformers/pipelines/document_question_answering.py +++ b/src/transformers/pipelines/document_question_answering.py @@ -331,12 +331,11 @@ def preprocess( if input.get("image", None) is not None: image = load_image(input["image"], timeout=timeout) if self.image_processor is not None: - image_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - image_inputs = image_inputs.to(self.dtype) + image_inputs = self.image_processor(images=image, return_tensors="pt") + image_inputs = image_inputs.to(self.dtype) image_features.update(image_inputs) elif self.feature_extractor is not None: - image_features.update(self.feature_extractor(images=image, return_tensors=self.framework)) + image_features.update(self.feature_extractor(images=image, return_tensors="pt")) elif self.model_type == ModelType.VisionEncoderDecoder: raise ValueError("If you are using a VisionEncoderDecoderModel, you must provide a feature extractor") @@ -374,7 +373,7 @@ def preprocess( encoding = { "inputs": image_features["pixel_values"], "decoder_input_ids": self.tokenizer( - task_prompt, add_special_tokens=False, return_tensors=self.framework + task_prompt, add_special_tokens=False, return_tensors="pt" ).input_ids, "return_dict_in_generate": True, } @@ -417,12 +416,9 @@ def preprocess( # This logic mirrors the logic in the question_answering pipeline p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)] for span_idx in range(num_spans): - if self.framework == "pt": - span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()} - if "pixel_values" in image_features: - span_encoding["image"] = image_features["pixel_values"] - else: - raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline") + span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()} + if "pixel_values" in image_features: + span_encoding["image"] = image_features["pixel_values"] input_ids_span_idx = encoding["input_ids"][span_idx] # keep the cls_token unmasked (some models use it to indicate unanswerable questions) @@ -447,10 +443,7 @@ def preprocess( else: bbox.append([0] * 4) - if self.framework == "pt": - span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0) - elif self.framework == "tf": - raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline") + span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0) yield { **span_encoding, "p_mask": p_mask[span_idx], @@ -515,9 +508,9 @@ def postprocess_extractive_qa( for output in model_outputs: words = output["words"] - if self.framework == "pt" and output["start_logits"].dtype in (torch.bfloat16, torch.float16): + if output["start_logits"].dtype in (torch.bfloat16, torch.float16): output["start_logits"] = output["start_logits"].float() - if self.framework == "pt" and output["end_logits"].dtype in (torch.bfloat16, torch.float16): + if output["end_logits"].dtype in (torch.bfloat16, torch.float16): output["end_logits"] = output["end_logits"].float() starts, ends, scores, min_null_score = select_starts_ends( diff --git a/src/transformers/pipelines/feature_extraction.py b/src/transformers/pipelines/feature_extraction.py index 9c8005d05f22..69c341d3d846 100644 --- a/src/transformers/pipelines/feature_extraction.py +++ b/src/transformers/pipelines/feature_extraction.py @@ -62,7 +62,7 @@ def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_ten return preprocess_params, {}, postprocess_params def preprocess(self, inputs, **tokenize_kwargs) -> dict[str, GenericTensor]: - model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs) + model_inputs = self.tokenizer(inputs, return_tensors="pt", **tokenize_kwargs) return model_inputs def _forward(self, model_inputs): @@ -73,10 +73,7 @@ def postprocess(self, model_outputs, return_tensors=False): # [0] is the first available tensor, logits or last_hidden_state. if return_tensors: return model_outputs[0] - if self.framework == "pt": - return model_outputs[0].tolist() - elif self.framework == "tf": - return model_outputs[0].numpy().tolist() + return model_outputs[0].tolist() def __call__(self, *args: Union[str, list[str]], **kwargs: Any) -> Union[Any, list[Any]]: """ diff --git a/src/transformers/pipelines/image_feature_extraction.py b/src/transformers/pipelines/image_feature_extraction.py index a87ecafb684e..d049957a4138 100644 --- a/src/transformers/pipelines/image_feature_extraction.py +++ b/src/transformers/pipelines/image_feature_extraction.py @@ -66,9 +66,8 @@ def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, def preprocess(self, image, timeout=None, **image_processor_kwargs) -> dict[str, GenericTensor]: image = load_image(image, timeout=timeout) - model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(image, return_tensors="pt", **image_processor_kwargs) + model_inputs = model_inputs.to(self.dtype) return model_inputs def _forward(self, model_inputs): @@ -90,10 +89,7 @@ def postprocess(self, model_outputs, pool=None, return_tensors=False): if return_tensors: return outputs - if self.framework == "pt": - return outputs.tolist() - elif self.framework == "tf": - return outputs.numpy().tolist() + return outputs.tolist() def __call__(self, *args: Union[str, "Image.Image", list["Image.Image"], list[str]], **kwargs: Any) -> list[Any]: """ diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index ed09f15d5e13..a6c7ef362d95 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -68,9 +68,6 @@ class ImageSegmentationPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.framework == "tf": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - requires_backends(self, "vision") mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy() mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES) @@ -160,18 +157,16 @@ def preprocess(self, image, subtask=None, timeout=None): else: kwargs = {"task_inputs": [subtask]} inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs) - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = inputs.to(self.dtype) inputs["task_inputs"] = self.tokenizer( inputs["task_inputs"], padding="max_length", max_length=self.model.config.task_seq_len, - return_tensors=self.framework, + return_tensors="pt", )["input_ids"] else: inputs = self.image_processor(images=[image], return_tensors="pt") - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = inputs.to(self.dtype) inputs["target_size"] = target_size return inputs diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py index 0db67f84d248..49739e383810 100644 --- a/src/transformers/pipelines/object_detection.py +++ b/src/transformers/pipelines/object_detection.py @@ -56,9 +56,6 @@ class ObjectDetectionPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.framework == "tf": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - requires_backends(self, "vision") mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy() mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES) @@ -121,8 +118,7 @@ def preprocess(self, image, timeout=None): image = load_image(image, timeout=timeout) target_size = torch.IntTensor([[image.height, image.width]]) inputs = self.image_processor(images=[image], return_tensors="pt") - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = inputs.to(self.dtype) if self.tokenizer is not None: inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt") inputs["target_size"] = target_size @@ -191,8 +187,6 @@ def _get_bounding_box(self, box: "torch.Tensor") -> dict[str, int]: Returns: bbox (`dict[str, int]`): Dict containing the coordinates in corners format. """ - if self.framework != "pt": - raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.") xmin, ymin, xmax, ymax = box.int().tolist() bbox = { "xmin": xmin, diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 17eaba1466b3..515736076892 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -100,9 +100,6 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time self.no_processor = no_processor - if self.framework == "tf": - raise ValueError("The TextToAudioPipeline is only available in PyTorch.") - self.vocoder = None if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values(): self.vocoder = ( diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py index 55154af9ab3b..ef4563027c19 100644 --- a/src/transformers/pipelines/zero_shot_object_detection.py +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -61,9 +61,6 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline): def __init__(self, **kwargs): super().__init__(**kwargs) - if self.framework == "tf": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - requires_backends(self, "vision") self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES) @@ -182,10 +179,9 @@ def preprocess(self, inputs, timeout=None): target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32) for i, candidate_label in enumerate(candidate_labels): - text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework) - image_features = self.image_processor(image, return_tensors=self.framework) - if self.framework == "pt": - image_features = image_features.to(self.dtype) + text_inputs = self.tokenizer(candidate_label, return_tensors="pt") + image_features = self.image_processor(image, return_tensors="pt") + image_features = image_features.to(self.dtype) yield { "is_last": i == len(candidate_labels) - 1, "target_size": target_size, @@ -236,8 +232,6 @@ def _get_bounding_box(self, box: "torch.Tensor") -> dict[str, int]: Returns: bbox (`dict[str, int]`): Dict containing the coordinates in corners format. """ - if self.framework != "pt": - raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.") xmin, ymin, xmax, ymax = box.int().tolist() bbox = { "xmin": xmin, diff --git a/src/transformers/tokenization_mistral_common.py b/src/transformers/tokenization_mistral_common.py index c81ed6831552..90d3b673e20e 100644 --- a/src/transformers/tokenization_mistral_common.py +++ b/src/transformers/tokenization_mistral_common.py @@ -1219,7 +1219,7 @@ def pad( encoded_inputs["attention_mask"] = [] return encoded_inputs - # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # If we have PyTorch/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch @@ -1607,11 +1607,6 @@ def __call__( "`text_pair`, `text_target` and `text_pair_target` are not supported by `MistralCommonTokenizer`." ) - if return_tensors in ("tf", "jax"): - raise ValueError( - "`MistralCommonTokenizer` does not support `return_tensors='tf'` or `return_tensors='jax'`." - ) - def _is_valid_text_input(t): if isinstance(t, str): # Strings are fine diff --git a/tests/fixtures/add_distilbert_like_config.json b/tests/fixtures/add_distilbert_like_config.json deleted file mode 100644 index 6603796a0418..000000000000 --- a/tests/fixtures/add_distilbert_like_config.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "add_copied_from": true, - "old_model_type": "distilbert", - "new_model_patterns": { - "model_name": "BERT New", - "checkpoint": "huggingface/bert-new-base", - "model_type": "bert-new", - "model_lower_cased": "bert_new", - "model_camel_cased": "BertNew", - "model_upper_cased": "BERT_NEW", - "config_class": "BertNewConfig", - "tokenizer_class": "DistilBertTokenizer" - }, - "frameworks": [ - "pt", - "tf", - "flax" - ] -} diff --git a/tests/models/camembert/test_tokenization_camembert.py b/tests/models/camembert/test_tokenization_camembert.py index 33a49b33958e..704c757fc41d 100644 --- a/tests/models/camembert/test_tokenization_camembert.py +++ b/tests/models/camembert/test_tokenization_camembert.py @@ -18,7 +18,6 @@ from transformers import AddedToken, CamembertTokenizer, CamembertTokenizerFast from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow -from transformers.utils import is_torch_available from ...test_tokenization_common import TokenizerTesterMixin @@ -26,8 +25,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") SAMPLE_BPE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model") -FRAMEWORK = "pt" if is_torch_available() else "tf" - @require_sentencepiece @require_tokenizers diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 456c8ea922d8..983b1e60a097 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -369,11 +369,6 @@ def run_pipeline_test(self, text_generator, _): with self.assertRaises((ValueError, AssertionError)): outputs = text_generator("", add_special_tokens=False) - if text_generator.framework == "tf": - # TF generation does not support max_new_tokens, and it's impossible - # to control long generation with only max_length without - # fancy calculation, dismissing tests for now. - self.skipTest(reason="TF generation does not support max_new_tokens") # We don't care about infinite range models. # They already work. # Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly. @@ -480,10 +475,7 @@ def test_pipeline_accelerate_top_p(self): def test_pipeline_length_setting_warning(self): prompt = """Hello world""" text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=5) - if text_generator.model.framework == "tf": - logger = logging.get_logger("transformers.generation.tf_utils") - else: - logger = logging.get_logger("transformers.generation.utils") + logger = logging.get_logger("transformers.generation.utils") logger_msg = "Both `max_new_tokens`" # The beginning of the message to be checked in this test # Both are set by the user -> log warning diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py index fc74223110f8..e1d98ae8ba4f 100644 --- a/tests/tokenization/test_tokenization_utils.py +++ b/tests/tokenization/test_tokenization_utils.py @@ -92,7 +92,6 @@ def test_pretrained_tokenizers(self): self.check_tokenizer_from_pretrained(GPT2Tokenizer) def test_tensor_type_from_str(self): - self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW) self.assertEqual(TensorType("pt"), TensorType.PYTORCH) self.assertEqual(TensorType("np"), TensorType.NUMPY) diff --git a/tests/utils/import_structures/import_structure_raw_register.py b/tests/utils/import_structures/import_structure_raw_register.py index a1df4a9c2e93..b57772e901b9 100644 --- a/tests/utils/import_structures/import_structure_raw_register.py +++ b/tests/utils/import_structures/import_structure_raw_register.py @@ -28,19 +28,19 @@ def a0(): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) class A1: def __init__(self): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) def a1(): pass @requires( - backends=("torch", "tf") + backends=("torch",) ) class A2: def __init__(self): @@ -48,7 +48,7 @@ def __init__(self): @requires( - backends=("torch", "tf") + backends=("torch",) ) def a2(): pass @@ -57,7 +57,6 @@ def a2(): @requires( backends=( "torch", - "tf" ) ) class A3: @@ -68,7 +67,6 @@ def __init__(self): @requires( backends=( "torch", - "tf" ) ) def a3(): diff --git a/tests/utils/import_structures/import_structure_register_with_comments.py b/tests/utils/import_structures/import_structure_register_with_comments.py index aed2b196ca68..9d367cb10772 100644 --- a/tests/utils/import_structures/import_structure_register_with_comments.py +++ b/tests/utils/import_structures/import_structure_register_with_comments.py @@ -30,27 +30,27 @@ def b0(): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) # That's a statement class B1: def __init__(self): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) # That's a statement def b1(): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) # That's a statement class B2: def __init__(self): pass -@requires(backends=("torch", "tf")) +@requires(backends=("torch",)) # That's a statement def b2(): pass @@ -59,7 +59,6 @@ def b2(): @requires( backends=( "torch", - "tf" ) ) # That's a statement @@ -71,7 +70,6 @@ def __init__(self): @requires( backends=( "torch", - "tf" ) ) # That's a statement diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 1a4588a0d393..d69f4f0df1e4 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -55,7 +55,7 @@ def test_definition(self): frozenset({"torch"}): { "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"} }, - frozenset({"tf", "torch"}): { + frozenset({"torch"}): { "import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"}, "import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"}, }, @@ -198,7 +198,6 @@ def test_import_spread(self): "backend,package_name,version_comparison,version", [ pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"), - pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"), pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"), ], ) diff --git a/utils/add_pipeline_model_mapping_to_test.py b/utils/add_pipeline_model_mapping_to_test.py index 636f018eb510..9e261da03d14 100644 --- a/utils/add_pipeline_model_mapping_to_test.py +++ b/utils/add_pipeline_model_mapping_to_test.py @@ -38,7 +38,7 @@ PIPELINE_TEST_MAPPING = {} for task in pipeline_test_mapping: - PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None} + PIPELINE_TEST_MAPPING[task] = None # DO **NOT** add item to this set (unless the reason is approved) @@ -47,47 +47,26 @@ } -def get_framework(test_class): - """Infer the framework from the test class `test_class`.""" - - if "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: - return "pt" - elif "TFModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: - return "tf" - elif "FlaxModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: - return "flax" - else: - return None - - -def get_mapping_for_task(task, framework): +def get_mapping_for_task(task): """Get mappings defined in `XXXPipelineTests` for the task `task`.""" # Use the cached results - if PIPELINE_TEST_MAPPING[task].get(framework, None) is not None: - return PIPELINE_TEST_MAPPING[task][framework] + if PIPELINE_TEST_MAPPING[task] is not None: + return PIPELINE_TEST_MAPPING[task] pipeline_test_class = pipeline_test_mapping[task]["test"] - mapping = None - - if framework == "pt": - mapping = getattr(pipeline_test_class, "model_mapping", None) - elif framework == "tf": - mapping = getattr(pipeline_test_class, "tf_model_mapping", None) + mapping = getattr(pipeline_test_class, "model_mapping", None) if mapping is not None: mapping = dict(mapping.items()) # cache the results - PIPELINE_TEST_MAPPING[task][framework] = mapping + PIPELINE_TEST_MAPPING[task] = mapping return mapping def get_model_for_pipeline_test(test_class, task): """Get the model architecture(s) related to the test class `test_class` for a pipeline `task`.""" - framework = get_framework(test_class) - if framework is None: - return None - mapping = get_mapping_for_task(task, framework) + mapping = get_mapping_for_task(task) if mapping is None: return None @@ -116,11 +95,7 @@ def get_pipeline_model_mapping_string(test_class): This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully. """ - framework = get_framework(test_class) - if framework == "pt": - framework = "torch" default_value = "{}" - mapping = get_pipeline_model_mapping(test_class) if len(mapping) == 0: return "" @@ -135,7 +110,7 @@ def get_pipeline_model_mapping_string(test_class): value = model_classes.__name__ texts.append(f'"{task}": {value}') text = "{" + ", ".join(texts) + "}" - text = f"pipeline_model_mapping = {text} if is_{framework}_available() else {default_value}" + text = f"pipeline_model_mapping = {text} if is_torch_available() else {default_value}" return text diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index c1ea0a29cc37..f0ecb7842cd1 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -307,16 +307,7 @@ def get_impacted_files_from_tiny_model_summary(diff_with_last_commit: bool = Fal # Get the corresponding modeling file path for model_class in impacted_model_classes: module = reversed_structure[model_class] - framework = "" - if model_class.startswith("TF"): - framework = "tf" - elif model_class.startswith("Flax"): - framework = "flax" - fn = ( - f"modeling_{module.split('.')[-1]}.py" - if framework == "" - else f"modeling_{framework}_{module.split('.')[-1]}.py" - ) + fn = f"modeling_{module.split('.')[-1]}.py" files.add(f"src.transformers.{module}.{fn}".replace(".", os.path.sep).replace(f"{os.path.sep}py", ".py")) return sorted(files) @@ -808,7 +799,7 @@ def init_test_examples_dependencies() -> tuple[dict[str, list[str]], list[str]]: """ The test examples do not import from the examples (which are just scripts, not modules) so we need some extra care initializing the dependency map, which is the goal of this function. It initializes the dependency map for - example files by linking each example to the example test file for the example framework. + example files by linking each example to the example test file for the example folder. Returns: `Tuple[Dict[str, List[str]], List[str]]`: A tuple with two elements: the initialized dependency map which is a @@ -820,7 +811,7 @@ def init_test_examples_dependencies() -> tuple[dict[str, list[str]], list[str]]: test_files = list((PATH_TO_EXAMPLES / "pytorch").glob("test_*.py")) all_examples.extend(test_files) - # Remove the files at the root of examples/framework since they are not proper examples (they are either utils + # Remove the files at the root of examples/pytorch since they are not proper examples (they are either utils # or example test files). examples = [f for f in (PATH_TO_EXAMPLES / "pytorch").glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / "pytorch"] all_examples.extend(examples) From b68ff887c179d459a2f04bb391312eee31c4e125 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 23:21:11 +0200 Subject: [PATCH 12/35] always more --- docker/transformers-past-gpu/Dockerfile | 59 --- .../multiple_choice/utils_multiple_choice.py | 5 - .../legacy/token-classification/utils_ner.py | 5 - src/transformers/__init__.py | 1 - src/transformers/convert_graph_to_onnx.py | 498 ------------------ src/transformers/data/data_collator.py | 1 - src/transformers/data/datasets/glue.py | 4 - .../data/datasets/language_modeling.py | 16 - src/transformers/data/datasets/squad.py | 4 - src/transformers/image_transforms.py | 3 +- src/transformers/modeling_utils.py | 16 - .../aya_vision/processing_aya_vision.py | 2 - .../models/beit/image_processing_beit.py | 3 +- .../models/beit/image_processing_beit_fast.py | 3 +- .../models/blip/processing_blip.py | 2 - .../models/blip_2/processing_blip_2.py | 2 - .../models/chameleon/processing_chameleon.py | 2 - .../models/clipseg/processing_clipseg.py | 2 - .../models/cohere/tokenization_cohere_fast.py | 4 - .../processing_cohere2_vision.py | 2 - .../models/colpali/modular_colpali.py | 6 - .../models/colpali/processing_colpali.py | 6 - .../models/colqwen2/modular_colqwen2.py | 2 - .../models/colqwen2/processing_colqwen2.py | 6 - .../image_processing_conditional_detr.py | 21 +- src/transformers/models/csm/processing_csm.py | 2 - .../models/deepseek_vl/modular_deepseek_vl.py | 2 - .../deepseek_vl/processing_deepseek_vl.py | 2 - .../modular_deepseek_vl_hybrid.py | 2 - .../processing_deepseek_vl_hybrid.py | 2 - .../image_processing_deformable_detr.py | 22 +- .../deprecated/deta/image_processing_deta.py | 20 +- .../models/detr/image_processing_detr.py | 23 +- .../models/detr/image_processing_detr_fast.py | 2 +- .../models/dpt/image_processing_dpt.py | 3 +- .../models/dpt/image_processing_dpt_fast.py | 3 +- .../models/emu3/processing_emu3.py | 2 - .../models/florence2/modular_florence2.py | 2 - .../models/florence2/processing_florence2.py | 2 - .../models/glm4v/modular_glm4v.py | 2 - .../models/glm4v/processing_glm4v.py | 2 - .../models/got_ocr2/processing_got_ocr2.py | 2 - .../image_processing_grounding_dino.py | 19 +- .../models/internvl/processing_internvl.py | 2 - .../models/janus/processing_janus.py | 2 - .../models/llama4/processing_llama4.py | 2 - .../models/llava/processing_llava.py | 2 - .../processing_llava_next_video.py | 2 - .../image_processing_mobilenet_v2.py | 3 +- .../image_processing_mobilenet_v2_fast.py | 3 +- .../mobilevit/image_processing_mobilevit.py | 3 +- .../owlv2/image_processing_owlv2_fast.py | 1 - .../models/owlvit/image_processing_owlvit.py | 1 - .../owlvit/image_processing_owlvit_fast.py | 1 - .../models/paligemma/processing_paligemma.py | 2 - .../perception_lm/processing_perception_lm.py | 2 - .../models/pixtral/processing_pixtral.py | 2 - .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 - .../qwen2_5_vl/processing_qwen2_5_vl.py | 2 - .../models/qwen2_vl/processing_qwen2_vl.py | 2 - .../rt_detr/image_processing_rt_detr.py | 20 +- .../segformer/image_processing_segformer.py | 3 +- .../image_processing_segformer_fast.py | 3 +- .../video_llava/processing_video_llava.py | 2 - .../vitpose/image_processing_vitpose.py | 2 - .../models/voxtral/processing_voxtral.py | 2 - .../models/yolos/image_processing_yolos.py | 22 +- .../pipelines/audio_classification.py | 3 - .../pipelines/automatic_speech_recognition.py | 2 +- .../pipelines/depth_estimation.py | 5 +- src/transformers/pipelines/fill_mask.py | 2 +- .../pipelines/image_text_to_text.py | 4 +- src/transformers/pipelines/image_to_image.py | 3 +- .../pipelines/keypoint_matching.py | 4 +- src/transformers/pipelines/mask_generation.py | 37 +- .../pipelines/text2text_generation.py | 6 +- src/transformers/pipelines/text_to_audio.py | 2 +- .../pipelines/video_classification.py | 20 +- .../pipelines/visual_question_answering.py | 14 +- .../zero_shot_audio_classification.py | 16 +- .../pipelines/zero_shot_classification.py | 9 +- src/transformers/processing_utils.py | 2 - src/transformers/training_args.py | 9 +- tests/models/fsmt/test_modeling_fsmt.py | 2 +- tests/models/marian/test_modeling_marian.py | 2 +- tests/models/upernet/test_modeling_upernet.py | 2 +- ..._pipelines_automatic_speech_recognition.py | 23 +- tests/pipelines/test_pipelines_common.py | 17 +- .../test_pipelines_feature_extraction.py | 12 +- tests/pipelines/test_pipelines_fill_mask.py | 7 +- ...test_pipelines_image_feature_extraction.py | 14 +- .../test_pipelines_question_answering.py | 2 +- .../pipelines/test_pipelines_summarization.py | 2 +- .../test_pipelines_text2text_generation.py | 1 - .../test_pipelines_text_classification.py | 7 +- .../test_pipelines_text_generation.py | 6 - .../pipelines/test_pipelines_text_to_audio.py | 18 +- .../test_pipelines_token_classification.py | 26 +- tests/pipelines/test_pipelines_translation.py | 4 +- tests/pipelines/test_pipelines_zero_shot.py | 8 +- utils/check_config_attributes.py | 1 - utils/create_dummy_models.py | 125 +++-- utils/not_doctested.txt | 1 - utils/test_module/custom_pipeline.py | 2 +- utils/update_metadata.py | 35 +- utils/update_tiny_models.py | 12 - 106 files changed, 192 insertions(+), 1152 deletions(-) delete mode 100644 docker/transformers-past-gpu/Dockerfile delete mode 100644 src/transformers/convert_graph_to_onnx.py diff --git a/docker/transformers-past-gpu/Dockerfile b/docker/transformers-past-gpu/Dockerfile deleted file mode 100644 index 34bfbb19cef5..000000000000 --- a/docker/transformers-past-gpu/Dockerfile +++ /dev/null @@ -1,59 +0,0 @@ -ARG BASE_DOCKER_IMAGE -FROM $BASE_DOCKER_IMAGE -LABEL maintainer="Hugging Face" - -ARG DEBIAN_FRONTEND=noninteractive - -# Use login shell to read variables from `~/.profile` (to pass dynamic created variables between RUN commands) -SHELL ["sh", "-lc"] - -RUN apt update -RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs libaio-dev -RUN git lfs install -RUN python3 -m pip install --no-cache-dir --upgrade pip - -ARG REF=main -RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF -RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] - -# When installing in editable mode, `transformers` is not recognized as a package. -# this line must be added in order for python to be aware of transformers. -RUN cd transformers && python3 setup.py develop - -ARG FRAMEWORK -ARG VERSION - -# Control `setuptools` version to avoid some issues -RUN [ "$VERSION" != "1.10" ] && python3 -m pip install -U setuptools || python3 -m pip install -U "setuptools<=59.5" - -# Remove all frameworks -RUN python3 -m pip uninstall -y torch torchvision torchaudio - -# Get the libraries and their versions to install, and write installation command to `~/.profile`. -RUN python3 ./transformers/utils/past_ci_versions.py --framework $FRAMEWORK --version $VERSION - -# Install the target framework -RUN echo "INSTALL_CMD = $INSTALL_CMD" -RUN $INSTALL_CMD - -RUN [ "$FRAMEWORK" != "pytorch" ] && echo "`deepspeed-testing` installation is skipped" || python3 -m pip install --no-cache-dir ./transformers[deepspeed-testing] - -# Remove `accelerate`: it requires `torch`, and this causes import issues for TF-only testing -# We will install `accelerate@main` in Past CI workflow file -RUN python3 -m pip uninstall -y accelerate - -# Uninstall `torch-tensorrt` and `apex` shipped with the base image -RUN python3 -m pip uninstall -y torch-tensorrt apex - -# Pre-build **nightly** release of DeepSpeed, so it would be ready for testing (otherwise, the 1st deepspeed test will timeout) -RUN python3 -m pip uninstall -y deepspeed -# This has to be run inside the GPU VMs running the tests. (So far, it fails here due to GPU checks during compilation.) -# Issue: https://github.com/deepspeedai/DeepSpeed/issues/2010 -# RUN git clone https://github.com/deepspeedai/DeepSpeed && cd DeepSpeed && rm -rf build && \ -# DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1 - -RUN python3 -m pip install -U "itsdangerous<2.1.0" - -# When installing in editable mode, `transformers` is not recognized as a package. -# this line must be added in order for python to be aware of transformers. -RUN cd transformers && python3 setup.py develop diff --git a/examples/legacy/multiple_choice/utils_multiple_choice.py b/examples/legacy/multiple_choice/utils_multiple_choice.py index b62dabf76c56..64d3604f9ca4 100644 --- a/examples/legacy/multiple_choice/utils_multiple_choice.py +++ b/examples/legacy/multiple_choice/utils_multiple_choice.py @@ -78,11 +78,6 @@ class Split(Enum): from torch.utils.data import Dataset class MultipleChoiceDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - features: list[InputFeatures] def __init__( diff --git a/examples/legacy/token-classification/utils_ner.py b/examples/legacy/token-classification/utils_ner.py index 809c95c26fd0..bfd792a250c3 100644 --- a/examples/legacy/token-classification/utils_ner.py +++ b/examples/legacy/token-classification/utils_ner.py @@ -208,11 +208,6 @@ def convert_examples_to_features( from torch.utils.data import Dataset class TokenClassificationDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - features: list[InputFeatures] pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index # Use cross entropy ignore_index as padding label id so that only diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b4df2baa235e..71fe0ae4ad48 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -60,7 +60,6 @@ "audio_utils": [], "commands": [], "configuration_utils": ["PretrainedConfig"], - "convert_graph_to_onnx": [], "convert_slow_tokenizers_checkpoints_to_fast": [], "data": [ "DataProcessor", diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py deleted file mode 100644 index 1d7109d333d3..000000000000 --- a/src/transformers/convert_graph_to_onnx.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from argparse import ArgumentParser -from os import listdir, makedirs -from pathlib import Path -from typing import Optional - -from packaging.version import Version, parse - -from transformers.pipelines import Pipeline, pipeline -from transformers.tokenization_utils import BatchEncoding -from transformers.utils import ModelOutput, is_torch_available - - -# This is the minimal required version to -# support some ONNX Runtime features -ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0") - - -SUPPORTED_PIPELINES = [ - "feature-extraction", - "ner", - "sentiment-analysis", - "fill-mask", - "question-answering", - "text-generation", - "translation_en_to_fr", - "translation_en_to_de", - "translation_en_to_ro", -] - - -class OnnxConverterArgumentParser(ArgumentParser): - """ - Wraps all the script arguments supported to export transformers models to ONNX IR - """ - - def __init__(self): - super().__init__("ONNX Converter") - - self.add_argument( - "--pipeline", - type=str, - choices=SUPPORTED_PIPELINES, - default="feature-extraction", - ) - self.add_argument( - "--model", - type=str, - required=True, - help="Model's id or path (ex: google-bert/bert-base-cased)", - ) - self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)") - self.add_argument( - "--framework", - type=str, - choices=["pt"], - default="pt", - help="Framework for loading the model", - ) - self.add_argument("--opset", type=int, default=11, help="ONNX opset to use") - self.add_argument( - "--check-loading", - action="store_true", - help="Check ONNX is able to load the model", - ) - self.add_argument( - "--use-external-format", - action="store_true", - help="Allow exporting model >= than 2Gb", - ) - self.add_argument( - "--quantize", - action="store_true", - help="Quantize the neural network to be run with int8", - ) - self.add_argument("output") - - -def generate_identified_filename(filename: Path, identifier: str) -> Path: - """ - Append a string-identifier at the end (before the extension, if any) to the provided filepath - - Args: - filename: pathlib.Path The actual path object we would like to add an identifier suffix - identifier: The suffix to add - - Returns: String with concatenated identifier at the end of the filename - """ - return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix) - - -def check_onnxruntime_requirements(minimum_version: Version): - """ - Check onnxruntime is installed and if the installed version match is recent enough - - Raises: - ImportError: If onnxruntime is not installed or too old version is found - """ - try: - import onnxruntime - - # Parse the version of the installed onnxruntime - ort_version = parse(onnxruntime.__version__) - - # We require 1.4.0 minimum - if ort_version < ORT_QUANTIZE_MINIMUM_VERSION: - raise ImportError( - f"We found an older version of onnxruntime ({onnxruntime.__version__}) " - f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n" - "Please update onnxruntime by running `pip install --upgrade onnxruntime`" - ) - - except ImportError: - raise ImportError( - "onnxruntime doesn't seem to be currently installed. " - "Please install the onnxruntime by running `pip install onnxruntime`" - " and relaunch the conversion." - ) - - -def ensure_valid_input(model, tokens, input_names): - """ - Ensure inputs are presented in the correct order, without any Non - - Args: - model: The model used to forward the input data - tokens: BatchEncoding holding the input data - input_names: The name of the inputs - - Returns: Tuple - - """ - print("Ensuring inputs are in correct order") - - model_args_name = model.forward.__code__.co_varnames - model_args, ordered_input_names = [], [] - for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument - if arg_name in input_names: - ordered_input_names.append(arg_name) - model_args.append(tokens[arg_name]) - else: - print(f"{arg_name} is not present in the generated input list.") - break - - print(f"Generated inputs order: {ordered_input_names}") - return ordered_input_names, tuple(model_args) - - -def infer_shapes(nlp: Pipeline, framework: str) -> tuple[list[str], list[str], dict, BatchEncoding]: - """ - Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model - - Args: - nlp: The pipeline object holding the model to be exported - framework: Not used anymore, only kept for BC - - Returns: - - - List of the inferred input variable names - - List of the inferred output variable names - - Dictionary with input/output variables names as key and shape tensor as value - - a BatchEncoding reference which was used to infer all the above information - """ - - def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int): - if isinstance(tensor, (tuple, list)): - return [build_shape_dict(name, t, is_input, seq_len) for t in tensor] - - else: - # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...) - axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"} - if is_input: - if len(tensor.shape) == 2: - axes[1] = "sequence" - else: - raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})") - else: - seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len] - axes.update(dict.fromkeys(seq_axes, "sequence")) - - print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}") - return axes - - tokens = nlp.tokenizer("This is a sample output", return_tensors="pt") - seq_len = tokens.input_ids.shape[-1] - outputs = nlp.model(**tokens) - if isinstance(outputs, ModelOutput): - outputs = outputs.to_tuple() - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - # Generate input names & axes - input_vars = list(tokens.keys()) - input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()} - - # flatten potentially grouped outputs (past for gpt2, attentions) - outputs_flat = [] - for output in outputs: - if isinstance(output, (tuple, list)): - outputs_flat.extend(output) - else: - outputs_flat.append(output) - - # Generate output names & axes - output_names = [f"output_{i}" for i in range(len(outputs_flat))] - output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)} - - # Create the aggregated axes representation - dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes) - return input_vars, output_names, dynamic_axes, tokens - - -def load_graph_from_args( - pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs -) -> Pipeline: - """ - Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model - - Args: - pipeline_name: The kind of pipeline to use (ner, question-answering, etc.) - framework: Not used anymore, only kept for BC - model: The model name which will be loaded by the pipeline - tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value - - Returns: Pipeline object - - """ - # If no tokenizer provided - if tokenizer is None: - tokenizer = model - - # Check pytorch is available - if not is_torch_available(): - raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") - - print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})") - - # Allocate tokenizer and model - return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework="pt", model_kwargs=models_kwargs) - - -def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool): - """ - Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR - - Args: - nlp: The pipeline to be exported - opset: The actual version of the ONNX operator set to use - output: Path where will be stored the generated ONNX model - use_external_format: Split the model definition from its parameters to allow model bigger than 2GB - - Returns: - - """ - if not is_torch_available(): - raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") - - import torch - from torch.onnx import export - - print(f"Using framework PyTorch: {torch.__version__}") - - with torch.no_grad(): - input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt") - ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names) - - export( - nlp.model, - model_args, - f=output.as_posix(), - input_names=ordered_input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - do_constant_folding=True, - opset_version=opset, - ) - - -def convert( - framework: str, - model: str, - output: Path, - opset: int, - tokenizer: Optional[str] = None, - use_external_format: bool = False, - pipeline_name: str = "feature-extraction", - **model_kwargs, -): - """ - Convert the pipeline object to the ONNX Intermediate Representation (IR) format - - Args: - framework: Not used anymore, only kept for BC - model: The name of the model to load for the pipeline - output: The path where the ONNX graph will be stored - opset: The actual version of the ONNX operator set to use - tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided - use_external_format: - Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only) - pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.) - model_kwargs: Keyword arguments to be forwarded to the model constructor - - Returns: - - """ - warnings.warn( - "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of" - " Transformers", - FutureWarning, - ) - print(f"ONNX opset version set to: {opset}") - - # Load the pipeline - nlp = load_graph_from_args(pipeline_name, "pt", model, tokenizer, **model_kwargs) - - if not output.parent.exists(): - print(f"Creating folder {output.parent}") - makedirs(output.parent.as_posix()) - elif len(listdir(output.parent.as_posix())) > 0: - raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion") - - # Export the graph - convert_pytorch(nlp, opset, output, use_external_format) - - -def optimize(onnx_model_path: Path) -> Path: - """ - Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the - optimizations possible - - Args: - onnx_model_path: filepath where the model binary description is stored - - Returns: Path where the optimized model binary description has been saved - - """ - from onnxruntime import InferenceSession, SessionOptions - - # Generate model name with suffix "optimized" - opt_model_path = generate_identified_filename(onnx_model_path, "-optimized") - sess_option = SessionOptions() - sess_option.optimized_model_filepath = opt_model_path.as_posix() - _ = InferenceSession(onnx_model_path.as_posix(), sess_option) - - print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}") - print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\") - - return opt_model_path - - -def quantize(onnx_model_path: Path) -> Path: - """ - Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU - - Args: - onnx_model_path: Path to location the exported ONNX model is stored - - Returns: The Path generated for the quantized - """ - import onnx - import onnxruntime - from onnx.onnx_pb import ModelProto - from onnxruntime.quantization import QuantizationMode - from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer - from onnxruntime.quantization.registry import IntegerOpsRegistry - - # Load the ONNX model - onnx_model = onnx.load(onnx_model_path.as_posix()) - - if parse(onnx.__version__) < parse("1.5.0"): - print( - "Models larger than 2GB will fail to quantize due to protobuf constraint.\n" - "Please upgrade to onnxruntime >= 1.5.0." - ) - - # Copy it - copy_model = ModelProto() - copy_model.CopyFrom(onnx_model) - - # Construct quantizer - # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we - # check the onnxruntime version to ensure backward compatibility. - # See also: https://github.com/microsoft/onnxruntime/pull/12873 - if parse(onnxruntime.__version__) < parse("1.13.1"): - quantizer = ONNXQuantizer( - model=copy_model, - per_channel=False, - reduce_range=False, - mode=QuantizationMode.IntegerOps, - static=False, - weight_qType=True, - input_qType=False, - tensors_range=None, - nodes_to_quantize=None, - nodes_to_exclude=None, - op_types_to_quantize=list(IntegerOpsRegistry), - ) - else: - quantizer = ONNXQuantizer( - model=copy_model, - per_channel=False, - reduce_range=False, - mode=QuantizationMode.IntegerOps, - static=False, - weight_qType=True, - activation_qType=False, - tensors_range=None, - nodes_to_quantize=None, - nodes_to_exclude=None, - op_types_to_quantize=list(IntegerOpsRegistry), - ) - - # Quantize and export - quantizer.quantize_model() - - # Append "-quantized" at the end of the model's name - quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized") - - # Save model - print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}") - onnx.save_model(quantizer.model.model, quantized_model_path.as_posix()) - - return quantized_model_path - - -def verify(path: Path): - from onnxruntime import InferenceSession, SessionOptions - from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException - - print(f"Checking ONNX model loading from: {path} ...") - try: - onnx_options = SessionOptions() - _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"]) - print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}") - except RuntimeException as re: - print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}") - - -if __name__ == "__main__": - parser = OnnxConverterArgumentParser() - args = parser.parse_args() - - # Make sure output is absolute path - args.output = Path(args.output).absolute() - - try: - print("\n====== Converting model to ONNX ======") - # Convert - convert( - args.framework, - args.model, - args.output, - args.opset, - args.tokenizer, - args.use_external_format, - args.pipeline, - ) - - if args.quantize: - # Ensure requirements for quantization on onnxruntime is met - check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION) - - print("\n====== Optimizing ONNX model ======") - - # Quantization works best when using the optimized version of the model - args.optimized_output = optimize(args.output) - - # Do the quantization on the right graph - args.quantized_output = quantize(args.optimized_output) - - # And verify - if args.check_loading: - print("\n====== Check exported ONNX model(s) ======") - verify(args.output) - - if hasattr(args, "optimized_output"): - verify(args.optimized_output) - - if hasattr(args, "quantized_output"): - verify(args.quantized_output) - - except Exception as e: - print(f"Error while converting the model: {e}") - exit(1) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index ff9bf6d7cd69..c9cdbdce97f4 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -729,7 +729,6 @@ def create_rng(self): # worker's generator, generated as the main seed + the worker's ID. # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading) # Only PyTorch DataLoader allows us to access the worker ID, and so we check for this. - # For other frameworks, we will throw an error. import torch worker_info = torch.utils.data.get_worker_info() diff --git a/src/transformers/data/datasets/glue.py b/src/transformers/data/datasets/glue.py index d8db0dfebac1..808eb1e50578 100644 --- a/src/transformers/data/datasets/glue.py +++ b/src/transformers/data/datasets/glue.py @@ -69,10 +69,6 @@ class Split(Enum): class GlueDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - args: GlueDataTrainingArguments output_mode: str features: list[InputFeatures] diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 07250ef3cb54..85d7e5360df3 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -38,10 +38,6 @@ class TextDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - def __init__( self, tokenizer: PreTrainedTokenizer, @@ -111,10 +107,6 @@ def __getitem__(self, i) -> torch.Tensor: class LineByLineTextDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int): warnings.warn( DEPRECATION_WARNING.format( @@ -144,10 +136,6 @@ def __getitem__(self, i) -> dict[str, torch.tensor]: class LineByLineWithRefDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str): warnings.warn( DEPRECATION_WARNING.format( @@ -344,10 +332,6 @@ def __getitem__(self, i) -> dict[str, torch.tensor]: class TextDatasetForNextSentencePrediction(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - def __init__( self, tokenizer: PreTrainedTokenizer, diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py index fdee571e249b..d96d8224d6b9 100644 --- a/src/transformers/data/datasets/squad.py +++ b/src/transformers/data/datasets/squad.py @@ -107,10 +107,6 @@ class Split(Enum): class SquadDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach soon. - """ - args: SquadDataTrainingArguments features: list[SquadFeatures] mode: Split diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index a83f6254e044..2aba3d549719 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -556,8 +556,7 @@ def center_to_corners_format(bboxes_center: TensorType) -> TensorType: corners format: contains the coordinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) """ - # Function is used during model forward pass, so we use the input framework if possible, without - # converting to numpy + # Function is used during model forward pass, so we use torch if relevant, without converting to numpy if is_torch_tensor(bboxes_center): return _center_to_corners_format_torch(bboxes_center) elif isinstance(bboxes_center, np.ndarray): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d5555f92b0b2..c4cc812f759a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2073,13 +2073,6 @@ def dummy_inputs(self) -> dict[str, torch.Tensor]: """ return {"input_ids": torch.tensor(DUMMY_INPUTS)} - @property - def framework(self) -> str: - """ - :str: Identifies that this is a PyTorch model. - """ - return "pt" - def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # For BC we keep the original `config_class` definition in case @@ -3718,9 +3711,6 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Activates gradient checkpointing for the current model. - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 @@ -3781,9 +3771,6 @@ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointin def gradient_checkpointing_disable(self): """ Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". """ if self.supports_gradient_checkpointing: # For old GC format (transformers < 4.35.0) for models that live on the Hub @@ -3805,9 +3792,6 @@ def gradient_checkpointing_disable(self): def is_gradient_checkpointing(self) -> bool: """ Whether gradient checkpointing is activated for this model or not. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) diff --git a/src/transformers/models/aya_vision/processing_aya_vision.py b/src/transformers/models/aya_vision/processing_aya_vision.py index 7045c967046d..aaede4e8e80e 100644 --- a/src/transformers/models/aya_vision/processing_aya_vision.py +++ b/src/transformers/models/aya_vision/processing_aya_vision.py @@ -160,10 +160,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index c25880bcfada..a93051b00b25 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -459,7 +459,7 @@ def preprocess( def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`BeitForSemanticSegmentation`]): @@ -473,7 +473,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/beit/image_processing_beit_fast.py b/src/transformers/models/beit/image_processing_beit_fast.py index e10dc552cf37..7a55543dee62 100644 --- a/src/transformers/models/beit/image_processing_beit_fast.py +++ b/src/transformers/models/beit/image_processing_beit_fast.py @@ -186,7 +186,7 @@ def _preprocess( def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`BeitForSemanticSegmentation`]): @@ -200,7 +200,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py index 5cc4334a974c..4ac741f84f46 100644 --- a/src/transformers/models/blip/processing_blip.py +++ b/src/transformers/models/blip/processing_blip.py @@ -86,10 +86,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. """ if images is None and text is None: raise ValueError("You have to specify either images or text.") diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index a1c89f7f460a..71f79583c77e 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -100,10 +100,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. """ if images is None and text is None: raise ValueError("You have to specify either images or text.") diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index d481a62b6fc6..bf4441c00a2e 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -114,10 +114,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/clipseg/processing_clipseg.py b/src/transformers/models/clipseg/processing_clipseg.py index 0af4c8ee12fc..e8cd47b0aa54 100644 --- a/src/transformers/models/clipseg/processing_clipseg.py +++ b/src/transformers/models/clipseg/processing_clipseg.py @@ -78,10 +78,8 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=No return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: diff --git a/src/transformers/models/cohere/tokenization_cohere_fast.py b/src/transformers/models/cohere/tokenization_cohere_fast.py index fd240b978480..8072cbe7c17c 100644 --- a/src/transformers/models/cohere/tokenization_cohere_fast.py +++ b/src/transformers/models/cohere/tokenization_cohere_fast.py @@ -276,10 +276,8 @@ def apply_tool_use_template( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.Tensor` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. return_dict (`bool`, *optional*, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. @@ -424,10 +422,8 @@ def apply_grounded_generation_template( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.Tensor` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. return_dict (`bool`, *optional*, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. diff --git a/src/transformers/models/cohere2_vision/processing_cohere2_vision.py b/src/transformers/models/cohere2_vision/processing_cohere2_vision.py index b72e1512ead9..cde77af658bc 100644 --- a/src/transformers/models/cohere2_vision/processing_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/processing_cohere2_vision.py @@ -103,10 +103,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py index cf28475f4b3c..0c932a732258 100644 --- a/src/transformers/models/colpali/modular_colpali.py +++ b/src/transformers/models/colpali/modular_colpali.py @@ -117,10 +117,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -224,10 +222,8 @@ def process_images( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -259,10 +255,8 @@ def process_queries( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py index b3f758a00006..5d77eced20d9 100644 --- a/src/transformers/models/colpali/processing_colpali.py +++ b/src/transformers/models/colpali/processing_colpali.py @@ -158,10 +158,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -292,10 +290,8 @@ def process_images( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -327,10 +323,8 @@ def process_queries( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index f3ae79abf6fa..a9a1f8ce3e1e 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -120,10 +120,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/colqwen2/processing_colqwen2.py b/src/transformers/models/colqwen2/processing_colqwen2.py index 1609f6e182da..372ce542d580 100644 --- a/src/transformers/models/colqwen2/processing_colqwen2.py +++ b/src/transformers/models/colqwen2/processing_colqwen2.py @@ -121,10 +121,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -277,10 +275,8 @@ def process_images( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -312,10 +308,8 @@ def process_queries( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index e68cab454929..cf506b834918 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -18,7 +18,7 @@ import pathlib from collections import defaultdict from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -189,23 +189,6 @@ def get_image_size_for_max_height_width( return new_height, new_width -# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - # Copied from transformers.models.detr.image_processing_detr.safe_squeeze def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ @@ -1510,11 +1493,9 @@ def preprocess( return encoded_inputs - # POSTPROCESSING METHODS - TODO: add support for other frameworks def post_process(self, outputs, target_sizes): """ Converts the output of [`ConditionalDetrForObjectDetection`] into the format expected by the Pascal VOC format (xmin, ymin, xmax, ymax). - Only supports PyTorch. Args: outputs ([`ConditionalDetrObjectDetectionOutput`]): diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py index 0f929f6a2a0c..7e16ecbb6001 100644 --- a/src/transformers/models/csm/processing_csm.py +++ b/src/transformers/models/csm/processing_csm.py @@ -226,10 +226,8 @@ def __call__( The ratio of audio frames to keep for the depth decoder labels. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py index 5bfc0ae7d74c..ce2f9be16ae6 100644 --- a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py @@ -261,10 +261,8 @@ def __call__( tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/deepseek_vl/processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/processing_deepseek_vl.py index 26d59d85a295..ddeb4f799ee1 100644 --- a/src/transformers/models/deepseek_vl/processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/processing_deepseek_vl.py @@ -92,10 +92,8 @@ def __call__( tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index d97b00f7fbd2..5149f7f42178 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -959,10 +959,8 @@ def __call__( tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py index 538fea5a6b32..d20fa495f9b8 100644 --- a/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py @@ -92,10 +92,8 @@ def __call__( tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 27998d605502..ef028eda1ed1 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -18,7 +18,7 @@ import pathlib from collections import defaultdict from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -187,23 +187,6 @@ def get_image_size_for_max_height_width( return new_height, new_width -# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - # Copied from transformers.models.detr.image_processing_detr.safe_squeeze def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ @@ -1508,11 +1491,10 @@ def preprocess( return encoded_inputs - # POSTPROCESSING METHODS - TODO: add support for other frameworks def post_process(self, outputs, target_sizes): """ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x, - top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch. + top_left_y, bottom_right_x, bottom_right_y) format. Args: outputs ([`DeformableDetrObjectDetectionOutput`]): diff --git a/src/transformers/models/deprecated/deta/image_processing_deta.py b/src/transformers/models/deprecated/deta/image_processing_deta.py index 15220603bb40..b54e07d240ea 100644 --- a/src/transformers/models/deprecated/deta/image_processing_deta.py +++ b/src/transformers/models/deprecated/deta/image_processing_deta.py @@ -16,7 +16,7 @@ import pathlib from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -51,7 +51,6 @@ ) from ....utils import ( is_torch_available, - is_torch_tensor, is_torchvision_available, is_vision_available, logging, @@ -177,23 +176,6 @@ def get_image_size_for_max_height_width( return new_height, new_width -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ Squeezes an array, but only if the axis specified has dim 1. diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index 80287942b5f9..7a2e67f83de6 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -18,7 +18,7 @@ import pathlib from collections import defaultdict from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -185,22 +185,6 @@ def get_resize_output_image_size( return get_size_with_aspect_ratio(image_size, size, max_size) -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ Squeezes an array, but only if the axis specified has dim 1. @@ -623,7 +607,6 @@ def resize_annotation( return new_annotation -# TODO - (Amy) make compatible with other frameworks def binary_mask_to_rle(mask): """ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. @@ -646,7 +629,6 @@ def binary_mask_to_rle(mask): return list(runs) -# TODO - (Amy) make compatible with other frameworks def convert_segmentation_to_rle(segmentation): """ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. @@ -1483,12 +1465,11 @@ def preprocess( return encoded_inputs - # POSTPROCESSING METHODS - TODO: add support for other frameworks # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258 def post_process(self, outputs, target_sizes): """ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, - bottom_right_x, bottom_right_y) format. Only supports PyTorch. + bottom_right_x, bottom_right_y) format. Args: outputs ([`DetrObjectDetectionOutput`]): diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py index 96a89a98074c..f30ebfa41859 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -725,7 +725,7 @@ def _preprocess( def post_process(self, outputs, target_sizes): """ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, - bottom_right_x, bottom_right_y) format. Only supports PyTorch. + bottom_right_x, bottom_right_y) format. Args: outputs ([`DetrObjectDetectionOutput`]): diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 9b28950d2ded..189b34e9fb7b 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -591,7 +591,7 @@ def preprocess( # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`DPTForSemanticSegmentation`]): @@ -605,7 +605,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/dpt/image_processing_dpt_fast.py b/src/transformers/models/dpt/image_processing_dpt_fast.py index d4848c50653c..3e80ad7943db 100644 --- a/src/transformers/models/dpt/image_processing_dpt_fast.py +++ b/src/transformers/models/dpt/image_processing_dpt_fast.py @@ -256,7 +256,7 @@ def _preprocess( def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`DPTForSemanticSegmentation`]): @@ -270,7 +270,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/emu3/processing_emu3.py b/src/transformers/models/emu3/processing_emu3.py index 67ccab795733..ef2681d2385b 100644 --- a/src/transformers/models/emu3/processing_emu3.py +++ b/src/transformers/models/emu3/processing_emu3.py @@ -117,10 +117,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py index f8732257f102..102cff29d800 100644 --- a/src/transformers/models/florence2/modular_florence2.py +++ b/src/transformers/models/florence2/modular_florence2.py @@ -363,10 +363,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/florence2/processing_florence2.py b/src/transformers/models/florence2/processing_florence2.py index 91b63e9da7db..5ae0f4828bc1 100644 --- a/src/transformers/models/florence2/processing_florence2.py +++ b/src/transformers/models/florence2/processing_florence2.py @@ -171,10 +171,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 7c400edc51c3..526abd1138b1 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -1562,10 +1562,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 817da3630d52..511869bbcafd 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -117,10 +117,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py index 16c062ec63ad..35df3b5a3f05 100644 --- a/src/transformers/models/got_ocr2/processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py @@ -177,10 +177,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/grounding_dino/image_processing_grounding_dino.py b/src/transformers/models/grounding_dino/image_processing_grounding_dino.py index 2910ea471059..737cf2e670ee 100644 --- a/src/transformers/models/grounding_dino/image_processing_grounding_dino.py +++ b/src/transformers/models/grounding_dino/image_processing_grounding_dino.py @@ -18,7 +18,7 @@ import pathlib from collections import defaultdict from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np @@ -196,23 +196,6 @@ def get_image_size_for_max_height_width( return new_height, new_width -# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - # Copied from transformers.models.detr.image_processing_detr.safe_squeeze def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ diff --git a/src/transformers/models/internvl/processing_internvl.py b/src/transformers/models/internvl/processing_internvl.py index a13457886baf..12e6a6163ba8 100644 --- a/src/transformers/models/internvl/processing_internvl.py +++ b/src/transformers/models/internvl/processing_internvl.py @@ -180,10 +180,8 @@ def __call__( The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/janus/processing_janus.py b/src/transformers/models/janus/processing_janus.py index 2de97400043f..c2413e705756 100644 --- a/src/transformers/models/janus/processing_janus.py +++ b/src/transformers/models/janus/processing_janus.py @@ -102,10 +102,8 @@ def __call__( tensor. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/llama4/processing_llama4.py b/src/transformers/models/llama4/processing_llama4.py index ce590bc6f40b..47a0b4cd99fb 100644 --- a/src/transformers/models/llama4/processing_llama4.py +++ b/src/transformers/models/llama4/processing_llama4.py @@ -159,10 +159,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 63c07c20cbb9..398bd9d8d065 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -115,10 +115,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index b9dbc6650b63..e858a1784254 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -143,10 +143,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index eb6e6388bff4..665173d59b95 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -478,7 +478,7 @@ def preprocess( # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileNetV2 def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`MobileNetV2ForSemanticSegmentation`]): @@ -492,7 +492,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py index 97ca39da78bf..948f9e96d7d9 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py @@ -206,7 +206,7 @@ def _preprocess( # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2 def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`MobileNetV2ForSemanticSegmentation`]): @@ -220,7 +220,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index 5411023c3104..bbf76ff04023 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -473,7 +473,7 @@ def preprocess( # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`MobileViTForSemanticSegmentation`]): @@ -487,7 +487,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/owlv2/image_processing_owlv2_fast.py b/src/transformers/models/owlv2/image_processing_owlv2_fast.py index 70441feba3c2..c17a45b6e427 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2_fast.py +++ b/src/transformers/models/owlv2/image_processing_owlv2_fast.py @@ -86,7 +86,6 @@ def post_process(self, outputs, target_sizes): `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. """ - # TODO: (amy) add support for other frameworks warnings.warn( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index cc9c6cfdeaa8..58253073d5d8 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -461,7 +461,6 @@ def post_process(self, outputs, target_sizes): `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. """ - # TODO: (amy) add support for other frameworks warnings.warn( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", diff --git a/src/transformers/models/owlvit/image_processing_owlvit_fast.py b/src/transformers/models/owlvit/image_processing_owlvit_fast.py index 1e458f964a04..53d94313ece9 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit_fast.py +++ b/src/transformers/models/owlvit/image_processing_owlvit_fast.py @@ -65,7 +65,6 @@ def post_process(self, outputs, target_sizes): `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. """ - # TODO: (amy) add support for other frameworks warnings.warn( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 242627a0eb71..7bf7fe403d5f 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -192,10 +192,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. suffix (`str`, `list[str]`, `list[list[str]]`): The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md for more information. If your prompt is " What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench". diff --git a/src/transformers/models/perception_lm/processing_perception_lm.py b/src/transformers/models/perception_lm/processing_perception_lm.py index f61c54554d32..35f0fef6c4ca 100644 --- a/src/transformers/models/perception_lm/processing_perception_lm.py +++ b/src/transformers/models/perception_lm/processing_perception_lm.py @@ -110,10 +110,8 @@ def __call__( The video or batch of videos to be processed. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index bb868156fb40..bf4eb9307c72 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -142,10 +142,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index b59644c37df9..817d9708d1d6 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -908,10 +908,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index b357ba850deb..0b2fc3dbfc38 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -121,10 +121,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 5bbbf6ac1aec..b237cb6079fb 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -116,10 +116,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr.py b/src/transformers/models/rt_detr/image_processing_rt_detr.py index 3c0e994c374a..4603a21095e7 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -16,7 +16,7 @@ import pathlib from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -51,7 +51,6 @@ from ...utils import ( filter_out_non_signature_kwargs, is_torch_available, - is_torch_tensor, logging, requires_backends, ) @@ -170,23 +169,6 @@ def get_image_size_for_max_height_width( return new_height, new_width -# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - # Copied from transformers.models.detr.image_processing_detr.safe_squeeze def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index 46e66babe4de..4025b59e7ebb 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ b/src/transformers/models/segformer/image_processing_segformer.py @@ -426,7 +426,7 @@ def preprocess( # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->Segformer def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`SegformerForSemanticSegmentation`]): @@ -440,7 +440,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/segformer/image_processing_segformer_fast.py b/src/transformers/models/segformer/image_processing_segformer_fast.py index da4bef3e9ee8..dc18283136e1 100644 --- a/src/transformers/models/segformer/image_processing_segformer_fast.py +++ b/src/transformers/models/segformer/image_processing_segformer_fast.py @@ -196,7 +196,7 @@ def _preprocess( def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ - Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Args: outputs ([`SegformerForSemanticSegmentation`]): @@ -210,7 +210,6 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[lis segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ - # TODO: add support for other frameworks logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index fc7a396853a8..a6f826fa72a3 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -133,10 +133,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/vitpose/image_processing_vitpose.py b/src/transformers/models/vitpose/image_processing_vitpose.py index 5bdefe3064bb..1fffc0e389c5 100644 --- a/src/transformers/models/vitpose/image_processing_vitpose.py +++ b/src/transformers/models/vitpose/image_processing_vitpose.py @@ -465,10 +465,8 @@ def preprocess( return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/voxtral/processing_voxtral.py b/src/transformers/models/voxtral/processing_voxtral.py index 0cf2d121f9da..1166c9636307 100644 --- a/src/transformers/models/voxtral/processing_voxtral.py +++ b/src/transformers/models/voxtral/processing_voxtral.py @@ -251,10 +251,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index 8430b4cc6d67..50da604db8d3 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -16,7 +16,7 @@ import pathlib from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -213,23 +213,6 @@ def get_resize_output_image_size( return get_size_with_aspect_ratio(image_size, size, max_size) -# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn -def get_numpy_to_framework_fn(arr) -> Callable: - """ - Returns a function that converts a numpy array to the framework of the input array. - - Args: - arr (`np.ndarray`): The array to convert. - """ - if isinstance(arr, np.ndarray): - return np.array - if is_torch_available() and is_torch_tensor(arr): - import torch - - return torch.tensor - raise ValueError(f"Cannot convert arrays of type {type(arr)}") - - # Copied from transformers.models.detr.image_processing_detr.safe_squeeze def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: """ @@ -1419,12 +1402,11 @@ def preprocess( return encoded_inputs - # POSTPROCESSING METHODS - TODO: add support for other frameworks # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process with Detr->Yolos def post_process(self, outputs, target_sizes): """ Converts the raw output of [`YolosForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, - bottom_right_x, bottom_right_y) format. Only supports PyTorch. + bottom_right_x, bottom_right_y) format. Args: outputs ([`YolosObjectDetectionOutput`]): diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py index 9f4822e2b2be..58fb3ab4fcab 100644 --- a/src/transformers/pipelines/audio_classification.py +++ b/src/transformers/pipelines/audio_classification.py @@ -103,9 +103,6 @@ def __init__(self, *args, **kwargs): kwargs["top_k"] = 5 super().__init__(*args, **kwargs) - if self.framework != "pt": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES) def __call__(self, inputs: Union[np.ndarray, bytes, str, dict], **kwargs: Any) -> list[dict[str, Any]]: diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 35bd35fc6b29..960bc00f4c51 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -587,7 +587,7 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16): + if outputs[key].dtype in (torch.bfloat16, torch.float16): items = outputs[key].to(torch.float32).numpy() else: items = outputs[key].numpy() diff --git a/src/transformers/pipelines/depth_estimation.py b/src/transformers/pipelines/depth_estimation.py index 588cee770639..36bbe46b4e3e 100644 --- a/src/transformers/pipelines/depth_estimation.py +++ b/src/transformers/pipelines/depth_estimation.py @@ -115,9 +115,8 @@ def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs): def preprocess(self, image, timeout=None): image = load_image(image, timeout) - model_inputs = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(images=image, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) model_inputs["target_size"] = image.size[::-1] return model_inputs diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index 49a45b5a7f5c..eb5ec7d0d362 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -109,7 +109,7 @@ def preprocess( self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters ) -> dict[str, GenericTensor]: if return_tensors is None: - return_tensors = self.framework + return_tensors = "pt" if tokenizer_kwargs is None: tokenizer_kwargs = {} diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 2564d53ba1d7..d42ed96213a6 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -396,7 +396,7 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p inputs.messages, add_generation_prompt=not continue_final_message, continue_final_message=continue_final_message, - return_tensors=self.framework, + return_tensors="pt", tokenize=True, return_dict=True, ) @@ -415,7 +415,7 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p # if batched text inputs, we set padding to True unless specified otherwise if isinstance(text, (list, tuple)) and len(text) > 1: processing_kwargs.setdefault("padding", True) - model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **processing_kwargs).to( + model_inputs = self.processor(images=images, text=text, return_tensors="pt", **processing_kwargs).to( dtype=self.dtype ) diff --git a/src/transformers/pipelines/image_to_image.py b/src/transformers/pipelines/image_to_image.py index d469024bff17..094a511449d6 100644 --- a/src/transformers/pipelines/image_to_image.py +++ b/src/transformers/pipelines/image_to_image.py @@ -130,8 +130,7 @@ def _forward(self, model_inputs): def preprocess(self, image, timeout=None): image = load_image(image, timeout=timeout) inputs = self.image_processor(images=[image], return_tensors="pt") - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = inputs.to(self.dtype) return inputs def postprocess(self, model_outputs): diff --git a/src/transformers/pipelines/keypoint_matching.py b/src/transformers/pipelines/keypoint_matching.py index 6878f40ad985..1e0d57d254e0 100644 --- a/src/transformers/pipelines/keypoint_matching.py +++ b/src/transformers/pipelines/keypoint_matching.py @@ -79,8 +79,6 @@ class KeypointMatchingPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) requires_backends(self, "vision") - if self.framework != "pt": - raise ValueError("Keypoint matching pipeline only supports PyTorch (framework='pt').") def _sanitize_parameters(self, threshold=None, timeout=None): preprocess_params = {} @@ -146,7 +144,7 @@ def __call__( def preprocess(self, images, timeout=None): images = [load_image(image, timeout=timeout) for image in images] - model_inputs = self.image_processor(images=images, return_tensors=self.framework) + model_inputs = self.image_processor(images=images, return_tensors="pt") model_inputs = model_inputs.to(self.dtype) target_sizes = [image.size for image in images] preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes} diff --git a/src/transformers/pipelines/mask_generation.py b/src/transformers/pipelines/mask_generation.py index 3a65fdff617a..f7354807afa2 100644 --- a/src/transformers/pipelines/mask_generation.py +++ b/src/transformers/pipelines/mask_generation.py @@ -94,9 +94,6 @@ def __init__(self, **kwargs): requires_backends(self, "vision") requires_backends(self, "torch") - if self.framework != "pt": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) def _sanitize_parameters(self, **kwargs): @@ -205,26 +202,24 @@ def preprocess( image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor ) model_inputs = self.image_processor(images=cropped_images, return_tensors="pt") - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = model_inputs.to(self.dtype) with self.device_placement(): - if self.framework == "pt": - inference_context = self.get_inference_context() - with inference_context(): - model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) - embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) - - # Handle both SAM (single tensor) and SAM-HQ (tuple) outputs - if isinstance(embeddings, tuple): - image_embeddings, intermediate_embeddings = embeddings - model_inputs["intermediate_embeddings"] = intermediate_embeddings - else: - image_embeddings = embeddings - # TODO: Identifying the model by the type of its returned embeddings is brittle. - # Consider using a more robust method for distinguishing model types here. - - model_inputs["image_embeddings"] = image_embeddings + inference_context = self.get_inference_context() + with inference_context(): + model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) + embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) + + # Handle both SAM (single tensor) and SAM-HQ (tuple) outputs + if isinstance(embeddings, tuple): + image_embeddings, intermediate_embeddings = embeddings + model_inputs["intermediate_embeddings"] = intermediate_embeddings + else: + image_embeddings = embeddings + # TODO: Identifying the model by the type of its returned embeddings is brittle. + # Consider using a more robust method for distinguishing model types here. + + model_inputs["image_embeddings"] = image_embeddings n_points = grid_points.shape[1] points_per_batch = points_per_batch if points_per_batch is not None else n_points diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index c9da04d37154..eb7e0bce8a34 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -169,7 +169,7 @@ def __call__(self, *args: Union[str, list[str]], **kwargs: Any) -> list[dict[str max_length instead of throwing an error down the line. generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./text_generation)). + [here](./text_generation)). Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: @@ -273,7 +273,7 @@ def __call__(self, *args, **kwargs): Whether or not to clean up the potential extra spaces in the text output. generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./text_generation)). + [here](./text_generation)). Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: @@ -380,7 +380,7 @@ def __call__(self, *args, **kwargs): for single pair translation models generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./text_generation)). + [here](./text_generation)). Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 515736076892..d43695b37399 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -60,7 +60,7 @@ class TextToAudioPipeline(Pipeline): ```python >>> from transformers import pipeline - >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small") >>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length >>> generate_kwargs = { diff --git a/src/transformers/pipelines/video_classification.py b/src/transformers/pipelines/video_classification.py index 1ee8dc86e161..ab57d46e7ccd 100644 --- a/src/transformers/pipelines/video_classification.py +++ b/src/transformers/pipelines/video_classification.py @@ -153,9 +153,8 @@ def preprocess(self, video, num_frames=None, frame_sampling_rate=1): video = read_video_pyav(container, indices) video = list(video) - model_inputs = self.image_processor(video, return_tensors=self.framework) - if self.framework == "pt": - model_inputs = model_inputs.to(self.dtype) + model_inputs = self.image_processor(video, return_tensors="pt") + model_inputs = model_inputs.to(self.dtype) return model_inputs def _forward(self, model_inputs): @@ -166,16 +165,13 @@ def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"): if top_k > self.model.config.num_labels: top_k = self.model.config.num_labels - if self.framework == "pt": - if function_to_apply == "softmax": - probs = model_outputs.logits[0].softmax(-1) - elif function_to_apply == "sigmoid": - probs = model_outputs.logits[0].sigmoid() - else: - probs = model_outputs.logits[0] - scores, ids = probs.topk(top_k) + if function_to_apply == "softmax": + probs = model_outputs.logits[0].softmax(-1) + elif function_to_apply == "sigmoid": + probs = model_outputs.logits[0].sigmoid() else: - raise ValueError(f"Unsupported framework: {self.framework}") + probs = model_outputs.logits[0] + scores, ids = probs.topk(top_k) scores = scores.tolist() ids = ids.tolist() diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index 609eaf2e9d55..c3f0514e1a8f 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -174,13 +174,12 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None): image = load_image(inputs["image"], timeout=timeout) model_inputs = self.tokenizer( inputs["question"], - return_tensors=self.framework, + return_tensors="pt", padding=padding, truncation=truncation, ) - image_features = self.image_processor(images=image, return_tensors=self.framework) - if self.framework == "pt": - image_features = image_features.to(self.dtype) + image_features = self.image_processor(images=image, return_tensors="pt") + image_features = image_features.to(self.dtype) model_inputs.update(image_features) return model_inputs @@ -205,11 +204,8 @@ def postprocess(self, model_outputs, top_k=5): if top_k > self.model.config.num_labels: top_k = self.model.config.num_labels - if self.framework == "pt": - probs = model_outputs.logits.sigmoid()[0] - scores, ids = probs.topk(top_k) - else: - raise ValueError(f"Unsupported framework: {self.framework}") + probs = model_outputs.logits.sigmoid()[0] + scores, ids = probs.topk(top_k) scores = scores.tolist() ids = ids.tolist() diff --git a/src/transformers/pipelines/zero_shot_audio_classification.py b/src/transformers/pipelines/zero_shot_audio_classification.py index 9c21681a0d8e..fa9a2fe6ecfc 100644 --- a/src/transformers/pipelines/zero_shot_audio_classification.py +++ b/src/transformers/pipelines/zero_shot_audio_classification.py @@ -68,10 +68,6 @@ class ZeroShotAudioClassificationPipeline(Pipeline): def __init__(self, **kwargs): super().__init__(**kwargs) - if self.framework != "pt": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - # No specific FOR_XXX available yet - def __call__(self, audios: Union[np.ndarray, bytes, str, dict], **kwargs: Any) -> list[dict[str, Any]]: """ Assign labels to the audio(s) passed as inputs. @@ -127,11 +123,10 @@ def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is inputs = self.feature_extractor( [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" ) - if self.framework == "pt": - inputs = inputs.to(self.dtype) + inputs = inputs.to(self.dtype) inputs["candidate_labels"] = candidate_labels sequences = [hypothesis_template.format(x) for x in candidate_labels] - text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True) + text_inputs = self.tokenizer(sequences, return_tensors="pt", padding=True) inputs["text_inputs"] = [text_inputs] return inputs @@ -156,11 +151,8 @@ def postprocess(self, model_outputs): candidate_labels = model_outputs.pop("candidate_labels") logits = model_outputs["logits"][0] - if self.framework == "pt": - probs = logits.softmax(dim=0) - scores = probs.tolist() - else: - raise ValueError("`tf` framework not supported.") + probs = logits.softmax(dim=0) + scores = probs.tolist() result = [ {"score": score, "label": candidate_label} diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 20675d4a2928..7d30d85b61cf 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -109,7 +109,7 @@ def _parse_and_tokenize( """ Parse arguments and tokenize only_first so that hypothesis (label) is not truncated """ - return_tensors = self.framework + return_tensors = "pt" if self.tokenizer.pad_token is None: # Override for tokenizers not supporting padding logger.error( @@ -226,7 +226,7 @@ def _forward(self, inputs): sequence = inputs["sequence"] model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names} # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported - model_forward = self.model.forward if self.framework == "pt" else self.model.call + model_forward = self.model.forward if "use_cache" in inspect.signature(model_forward).parameters: model_inputs["use_cache"] = False outputs = self.model(**model_inputs) @@ -242,10 +242,7 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - if self.framework == "pt": - logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) - else: - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b864d2971cae..faaae0c32157 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -563,10 +563,8 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] object with processed inputs in a dict format. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b232dcb76454..be77f1876f3c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -822,7 +822,6 @@ class TrainingArguments: "gradient_checkpointing_kwargs", "lr_scheduler_kwargs", ] - framework = "pt" output_dir: Optional[str] = field( default=None, @@ -1705,7 +1704,7 @@ def __post_init__(self): self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: self.greater_is_better = not self.metric_for_best_model.endswith("loss") - if self.framework == "pt" and is_torch_available(): + if is_torch_available(): if self.fp16_backend and self.fp16_backend != "auto": warnings.warn( "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" @@ -1787,7 +1786,7 @@ def __post_init__(self): ) # Initialize device before we proceed - if self.framework == "pt" and is_torch_available(): + if is_torch_available(): self.device if self.torchdynamo is not None: @@ -1813,7 +1812,7 @@ def __post_init__(self): if self.torch_compile_mode is not None: os.environ[prefix + "MODE"] = self.torch_compile_mode - if self.framework == "pt" and is_torch_available() and self.torch_compile: + if is_torch_available() and self.torch_compile: if is_torch_tf32_available(): if self.tf32 is None and not self.fp16 or self.bf16: device_str = "MUSA" if is_torch_musa_available() else "CUDA" @@ -1830,7 +1829,7 @@ def __post_init__(self): logger.warning( "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here." ) - if self.framework == "pt" and is_torch_available() and self.tf32 is not None: + if is_torch_available() and self.tf32 is not None: if self.tf32: if is_torch_tf32_available(): if is_torch_musa_available(): diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 39a2d5b26a24..df57cc1dba83 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -545,7 +545,7 @@ def test_translation_direct(self, pair): @slow def test_translation_pipeline(self, pair): tokenizer, model, src_text, tgt_text = self.translation_setup(pair) - pipeline = TranslationPipeline(model, tokenizer, framework="pt", device=torch_device) + pipeline = TranslationPipeline(model, tokenizer, device=torch_device) output = pipeline([src_text]) self.assertEqual([tgt_text], [x["translation_text"] for x in output]) diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 8f938cb7b0f7..3387e785ccbe 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -597,7 +597,7 @@ def test_batch_generation_en_ROMANCE_multi(self): @slow @require_torch def test_pipeline(self): - pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=torch_device) + pipeline = TranslationPipeline(self.model, self.tokenizer, device=torch_device) output = pipeline(self.src_text) self.assertEqual(self.expected_text, [x["translation_text"] for x in output]) diff --git a/tests/models/upernet/test_modeling_upernet.py b/tests/models/upernet/test_modeling_upernet.py index 9bca31677f36..349766fe575e 100644 --- a/tests/models/upernet/test_modeling_upernet.py +++ b/tests/models/upernet/test_modeling_upernet.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch UperNet framework.""" +"""Testing suite for the PyTorch UperNet.""" import unittest diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index c7aa7b686b1f..8776961bec14 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -160,7 +160,7 @@ def run_pipeline_test(self, speech_recognizer, examples): @require_torch def test_pt_defaults(self): - pipeline("automatic-speech-recognition", framework="pt") + pipeline("automatic-speech-recognition") @require_torch def test_small_model_pt(self): @@ -168,7 +168,6 @@ def test_small_model_pt(self): task="automatic-speech-recognition", model="facebook/s2t-small-mustc-en-fr-st", tokenizer="facebook/s2t-small-mustc-en-fr-st", - framework="pt", ) waveform = np.tile(np.arange(1000, dtype=np.float32), 34) output = speech_recognizer(waveform) @@ -188,7 +187,6 @@ def test_small_model_pt_fp16(self): task="automatic-speech-recognition", model="facebook/s2t-small-mustc-en-fr-st", tokenizer="facebook/s2t-small-mustc-en-fr-st", - framework="pt", dtype=torch.float16, ) waveform = np.tile(np.arange(1000, dtype=np.float32), 34) @@ -209,7 +207,6 @@ def test_small_model_pt_bf16(self): task="automatic-speech-recognition", model="facebook/s2t-small-mustc-en-fr-st", tokenizer="facebook/s2t-small-mustc-en-fr-st", - framework="pt", dtype=torch.bfloat16, ) waveform = np.tile(np.arange(1000, dtype=np.float32), 34) @@ -239,7 +236,6 @@ def test_whisper_fp16(self): def test_small_model_pt_seq2seq(self): speech_recognizer = pipeline( model="hf-internal-testing/tiny-random-speech-encoder-decoder", - framework="pt", max_new_tokens=19, num_beams=1, ) @@ -252,7 +248,6 @@ def test_small_model_pt_seq2seq(self): def test_small_model_pt_seq2seq_gen_kwargs(self): speech_recognizer = pipeline( model="hf-internal-testing/tiny-random-speech-encoder-decoder", - framework="pt", max_new_tokens=10, ) @@ -269,7 +264,6 @@ def test_large_model_pt_with_lm(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm", - framework="pt", ) self.assertEqual(speech_recognizer.type, "ctc_with_lm") @@ -333,7 +327,6 @@ def test_torch_small_no_tokenizer_files(self): pipeline( task="automatic-speech-recognition", model="patrickvonplaten/tiny-wav2vec2-no-tokenizer", - framework="pt", ) @require_torch @@ -343,7 +336,6 @@ def test_torch_large(self): task="automatic-speech-recognition", model="facebook/wav2vec2-base-960h", tokenizer="facebook/wav2vec2-base-960h", - framework="pt", ) waveform = np.tile(np.arange(1000, dtype=np.float32), 34) output = speech_recognizer(waveform) @@ -360,7 +352,6 @@ def test_torch_large_with_input_features(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="hf-audio/wav2vec2-bert-CV16-en", - framework="pt", ) waveform = np.tile(np.arange(1000, dtype=np.float32), 34) output = speech_recognizer(waveform) @@ -600,7 +591,6 @@ def test_torch_whisper(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-tiny", - framework="pt", num_beams=1, ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") @@ -617,7 +607,6 @@ def test_torch_whisper_batched(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-tiny", - framework="pt", num_beams=1, ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]") @@ -913,7 +902,6 @@ def test_torch_speech_encoder_decoder(self): task="automatic-speech-recognition", model="facebook/s2t-wav2vec2-large-en-de", feature_extractor="facebook/s2t-wav2vec2-large-en-de", - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") @@ -977,7 +965,6 @@ def test_simple_whisper_asr(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-tiny.en", - framework="pt", num_beams=1, ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") @@ -1047,7 +1034,6 @@ def test_simple_whisper_translation(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-large", - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") audio = ds[40]["audio"] @@ -1083,7 +1069,6 @@ def test_whisper_language(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-tiny.en", - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") audio = ds[0]["audio"] @@ -1107,7 +1092,6 @@ def test_whisper_language(self): speech_recognizer = pipeline( task="automatic-speech-recognition", model="openai/whisper-tiny", - framework="pt", ) output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"}) self.assertEqual( @@ -1207,7 +1191,6 @@ def test_xls_r_to_en(self): task="automatic-speech-recognition", model="facebook/wav2vec2-xls-r-1b-21-to-en", feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en", - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") @@ -1223,7 +1206,6 @@ def test_xls_r_from_en(self): task="automatic-speech-recognition", model="facebook/wav2vec2-xls-r-1b-en-to-15", feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15", - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") @@ -1240,7 +1222,6 @@ def test_speech_to_text_leveraged(self): model="patrickvonplaten/wav2vec2-2-bart-base", feature_extractor="patrickvonplaten/wav2vec2-2-bart-base", tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"), - framework="pt", ) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") @@ -1256,7 +1237,6 @@ def test_wav2vec2_conformer_float16(self): model="facebook/wav2vec2-conformer-rope-large-960h-ft", device=torch_device, dtype=torch.float16, - framework="pt", ) dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") @@ -1516,7 +1496,6 @@ def test_chunking_and_timestamps(self): model=model, tokenizer=tokenizer, feature_extractor=feature_extractor, - framework="pt", chunk_length_s=10.0, ) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2ef189cd2956..d94f09987b20 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -101,9 +101,7 @@ def __len__(self): def __getitem__(self, i): return self.data[i] - text_classifier = pipeline( - task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + text_classifier = pipeline(task="text-classification", model="hf-internal-testing/tiny-random-distilbert") dataset = MyDataset() for output in text_classifier(dataset): self.assertEqual(output, {"label": ANY(str), "score": ANY(float)}) @@ -248,9 +246,7 @@ class PipelineScikitCompatTest(unittest.TestCase): def test_pipeline_predict(self): data = ["This is a test"] - text_classifier = pipeline( - task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + text_classifier = pipeline(task="text-classification", model="hf-internal-testing/tiny-random-distilbert") expected_output = [{"label": ANY(str), "score": ANY(float)}] actual_output = text_classifier.predict(data) @@ -259,9 +255,7 @@ def test_pipeline_predict(self): def test_pipeline_transform(self): data = ["This is a test"] - text_classifier = pipeline( - task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + text_classifier = pipeline(task="text-classification", model="hf-internal-testing/tiny-random-distilbert") expected_output = [{"label": ANY(str), "score": ANY(float)}] actual_output = text_classifier.transform(data) @@ -640,7 +634,6 @@ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equa if len(relevant_auto_classes) == 0: # task has no default - logger.debug(f"{task} in {framework} has no default") self.skipTest(f"{task} in {framework} has no default") # by default use first class @@ -680,14 +673,12 @@ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equa # load default pipeline set_seed_fn() - default_pipeline = pipeline(task, framework=framework) + default_pipeline = pipeline(task) # compare pipeline model with default model models_are_equal = check_models_equal_fn(default_pipeline.model, model) self.assertTrue(models_are_equal, f"{task} model doesn't match pipeline.") - logger.debug(f"{task} in {framework} succeeded with {model_id}.") - def check_models_equal_pt(self, model1, model2): models_are_equal = True for model1_p, model2_p in zip(model1.parameters(), model2.parameters()): diff --git a/tests/pipelines/test_pipelines_feature_extraction.py b/tests/pipelines/test_pipelines_feature_extraction.py index a8321da15eb8..2d8a5618eaf2 100644 --- a/tests/pipelines/test_pipelines_feature_extraction.py +++ b/tests/pipelines/test_pipelines_feature_extraction.py @@ -38,9 +38,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase): @require_torch def test_small_model_pt(self): - feature_extractor = pipeline( - task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + feature_extractor = pipeline(task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert") outputs = feature_extractor("This is a test") self.assertEqual( nested_simplify(outputs), @@ -48,9 +46,7 @@ def test_small_model_pt(self): @require_torch def test_tokenization_small_model_pt(self): - feature_extractor = pipeline( - task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + feature_extractor = pipeline(task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert") # test with empty parameters outputs = feature_extractor("This is a test") self.assertEqual( @@ -88,9 +84,7 @@ def test_tokenization_small_model_pt(self): @require_torch def test_return_tensors_pt(self): - feature_extractor = pipeline( - task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + feature_extractor = pipeline(task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert") outputs = feature_extractor("This is a test", return_tensors=True) self.assertTrue(torch.is_tensor(outputs)) diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py index af5ee8c4be0c..4c85529c0613 100644 --- a/tests/pipelines/test_pipelines_fill_mask.py +++ b/tests/pipelines/test_pipelines_fill_mask.py @@ -44,7 +44,7 @@ def tearDown(self): @require_torch def test_small_model_pt(self): - unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="pt") + unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2) outputs = unmasker("My name is ") self.assertEqual( @@ -111,7 +111,6 @@ def test_fp16_casting(self): "fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=torch_device, - framework="pt", ) # convert model to fp16 @@ -126,7 +125,7 @@ def test_fp16_casting(self): @slow @require_torch def test_large_model_pt(self): - unmasker = pipeline(task="fill-mask", model="distilbert/distilroberta-base", top_k=2, framework="pt") + unmasker = pipeline(task="fill-mask", model="distilbert/distilroberta-base", top_k=2) self.run_large_test(unmasker) def run_large_test(self, unmasker): @@ -190,7 +189,7 @@ def run_large_test(self, unmasker): @require_torch def test_model_no_pad_pt(self): - unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt") + unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base") unmasker.tokenizer.pad_token_id = None unmasker.tokenizer.pad_token = None self.run_pipeline_test(unmasker, []) diff --git a/tests/pipelines/test_pipelines_image_feature_extraction.py b/tests/pipelines/test_pipelines_image_feature_extraction.py index 2705e1385331..c80bb2ec3453 100644 --- a/tests/pipelines/test_pipelines_image_feature_extraction.py +++ b/tests/pipelines/test_pipelines_image_feature_extraction.py @@ -47,9 +47,7 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase): @require_torch def test_small_model_pt(self): - feature_extractor = pipeline( - task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt" - ) + feature_extractor = pipeline(task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit") img = prepare_img() outputs = feature_extractor(img) self.assertEqual( @@ -59,7 +57,7 @@ def test_small_model_pt(self): @require_torch def test_small_model_w_pooler_pt(self): feature_extractor = pipeline( - task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler", framework="pt" + task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit-w-pooler" ) img = prepare_img() outputs = feature_extractor(img, pool=True) @@ -69,9 +67,7 @@ def test_small_model_w_pooler_pt(self): @require_torch def test_image_processing_small_model_pt(self): - feature_extractor = pipeline( - task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt" - ) + feature_extractor = pipeline(task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit") # test with image processor parameters image_processor_kwargs = {"size": {"height": 300, "width": 300}} @@ -91,9 +87,7 @@ def test_image_processing_small_model_pt(self): @require_torch def test_return_tensors_pt(self): - feature_extractor = pipeline( - task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt" - ) + feature_extractor = pipeline(task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit") img = prepare_img() outputs = feature_extractor(img, return_tensors=True) self.assertTrue(torch.is_tensor(outputs)) diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index b87742ee03cf..fd02ae8eea94 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -200,7 +200,7 @@ def test_small_model_pt_bf16(self): @require_torch def test_small_model_pt_iterator(self): # https://github.com/huggingface/transformers/issues/18510 - pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt") + pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16) def data(): for i in range(10): diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index e58d9264b89d..02cd9fe084c6 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -82,7 +82,7 @@ def run_pipeline_test(self, summarizer, _): @require_torch def test_small_model_pt(self): - summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", framework="pt", max_new_tokens=19) + summarizer = pipeline(task="summarization", model="sshleifer/tiny-mbart", max_new_tokens=19) outputs = summarizer("This is a small test") self.assertEqual( outputs, diff --git a/tests/pipelines/test_pipelines_text2text_generation.py b/tests/pipelines/test_pipelines_text2text_generation.py index 730f707237db..b52d68e22cfb 100644 --- a/tests/pipelines/test_pipelines_text2text_generation.py +++ b/tests/pipelines/test_pipelines_text2text_generation.py @@ -87,7 +87,6 @@ def test_small_model_pt(self): generator = pipeline( "text2text-generation", model="patrickvonplaten/t5-tiny-random", - framework="pt", num_beams=1, max_new_tokens=9, ) diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 902cc59b0987..b2310e8e087f 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -50,9 +50,7 @@ class TextClassificationPipelineTests(unittest.TestCase): @require_torch def test_small_model_pt(self): - text_classifier = pipeline( - task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" - ) + text_classifier = pipeline(task="text-classification", model="hf-internal-testing/tiny-random-distilbert") outputs = text_classifier("This is great !") self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) @@ -112,7 +110,6 @@ def test_accepts_torch_device(self): text_classifier = pipeline( task="text-classification", model="hf-internal-testing/tiny-random-distilbert", - framework="pt", device=torch_device, ) @@ -124,7 +121,6 @@ def test_accepts_torch_fp16(self): text_classifier = pipeline( task="text-classification", model="hf-internal-testing/tiny-random-distilbert", - framework="pt", device=torch_device, dtype=torch.float16, ) @@ -137,7 +133,6 @@ def test_accepts_torch_bf16(self): text_classifier = pipeline( task="text-classification", model="hf-internal-testing/tiny-random-distilbert", - framework="pt", device=torch_device, dtype=torch.bfloat16, ) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 983b1e60a097..f0f576364c41 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -43,7 +43,6 @@ def test_small_model_pt(self): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-random-LlamaForCausalLM", - framework="pt", max_new_tokens=10, ) # Using `do_sample=False` to force deterministic output @@ -73,7 +72,6 @@ def test_small_chat_model_pt(self): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", - framework="pt", ) # Using `do_sample=False` to force deterministic output chat1 = [ @@ -121,7 +119,6 @@ def test_small_chat_model_continue_final_message(self): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", - framework="pt", ) # Using `do_sample=False` to force deterministic output chat1 = [ @@ -155,7 +152,6 @@ def test_small_chat_model_continue_final_message_override(self): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", - framework="pt", ) # Using `do_sample=False` to force deterministic output chat1 = [ @@ -203,7 +199,6 @@ def __getitem__(self, i): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", - framework="pt", ) dataset = MyDataset() @@ -230,7 +225,6 @@ def test_small_chat_model_with_iterator_pt(self): text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", - framework="pt", ) # Using `do_sample=False` to force deterministic output diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index e435ce800fe0..fb56d0c64b54 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -43,7 +43,7 @@ class TextToAudioPipelineTests(unittest.TestCase): @require_torch def test_small_musicgen_pt(self): music_generator = pipeline( - task="text-to-audio", model="facebook/musicgen-small", framework="pt", do_sample=False, max_new_tokens=5 + task="text-to-audio", model="facebook/musicgen-small", do_sample=False, max_new_tokens=5 ) outputs = music_generator("This is a test") @@ -55,7 +55,7 @@ def test_small_musicgen_pt(self): self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) # test batching, this time with parameterization in the forward pass - music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small") forward_params = {"do_sample": False, "max_new_tokens": 5} outputs = music_generator( ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 @@ -66,9 +66,7 @@ def test_small_musicgen_pt(self): @slow @require_torch def test_medium_seamless_m4t_pt(self): - speech_generator = pipeline( - task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt", max_new_tokens=5 - ) + speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", max_new_tokens=5) for forward_params in [{"tgt_lang": "eng"}, {"return_intermediate_token_ids": True, "tgt_lang": "eng"}]: outputs = speech_generator("This is a test", forward_params=forward_params) @@ -89,7 +87,7 @@ def test_medium_seamless_m4t_pt(self): @slow @require_torch def test_small_bark_pt(self): - speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt") + speech_generator = pipeline(task="text-to-audio", model="suno/bark-small") forward_params = { # Using `do_sample=False` to force deterministic output @@ -139,7 +137,7 @@ def test_small_bark_pt(self): @slow @require_torch_accelerator def test_conversion_additional_tensor(self): - speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=torch_device) + speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", device=torch_device) processor = AutoProcessor.from_pretrained("suno/bark-small") forward_params = { @@ -177,7 +175,7 @@ def test_conversion_additional_tensor(self): @require_torch def test_vits_model_pt(self): - speech_generator = pipeline(task="text-to-audio", model="facebook/mms-tts-eng", framework="pt") + speech_generator = pipeline(task="text-to-audio", model="facebook/mms-tts-eng") outputs = speech_generator("This is a test") self.assertEqual(outputs["sampling_rate"], 16000) @@ -197,7 +195,7 @@ def test_vits_model_pt(self): @require_torch def test_forward_model_kwargs(self): # use vits - a forward model - speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt") + speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk") # for reproducibility set_seed(555) @@ -221,7 +219,7 @@ def test_forward_model_kwargs(self): @require_torch def test_generative_model_kwargs(self): # use musicgen - a generative model - music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small") forward_params = { "do_sample": True, diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 19e6d342805c..5f04d3857cb8 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -544,7 +544,7 @@ def test_aggregation_strategy_byte_level_tokenizer(self): def test_aggregation_strategy_no_b_i_prefix(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer) # Just to understand scores indexes in this test token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"} example = [ @@ -593,7 +593,7 @@ def test_aggregation_strategy_no_b_i_prefix(self): def test_aggregation_strategy(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer) # Just to understand scores indexes in this test self.assertEqual( token_classifier.model.config.id2label, @@ -666,7 +666,7 @@ def test_aggregation_strategy(self): def test_aggregation_strategy_example2(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer) # Just to understand scores indexes in this test self.assertEqual( token_classifier.model.config.id2label, @@ -742,7 +742,7 @@ def test_aggregation_strategy_offsets_with_leading_space(self): def test_gather_pre_entities(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer) sentence = "Hello there" @@ -787,7 +787,7 @@ def test_gather_pre_entities(self): def test_word_heuristic_leading_space(self): model_name = "hf-internal-testing/tiny-random-deberta-v2" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer) sentence = "I play the theremin" @@ -822,7 +822,7 @@ def test_word_heuristic_leading_space(self): def test_no_offset_tokenizer(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) - token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer, framework="pt") + token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), @@ -835,7 +835,7 @@ def test_no_offset_tokenizer(self): @require_torch def test_small_model_pt(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" - token_classifier = pipeline(task="token-classification", model=model_name, framework="pt") + token_classifier = pipeline(task="token-classification", model=model_name) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), @@ -845,16 +845,14 @@ def test_small_model_pt(self): ], ) - token_classifier = pipeline( - task="token-classification", model=model_name, framework="pt", ignore_labels=["O", "I-MISC"] - ) + token_classifier = pipeline(task="token-classification", model=model_name, ignore_labels=["O", "I-MISC"]) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), [], ) - token_classifier = pipeline(task="token-classification", model=model_name, framework="pt") + token_classifier = pipeline(task="token-classification", model=model_name) # Overload offset_mapping outputs = token_classifier( "This is a test !", offset_mapping=[(0, 0), (0, 1), (0, 2), (0, 0), (0, 0), (0, 0), (0, 0)] @@ -887,7 +885,7 @@ def test_small_model_pt(self): @require_torch def test_small_model_pt_fp16(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" - token_classifier = pipeline(task="token-classification", model=model_name, framework="pt", dtype=torch.float16) + token_classifier = pipeline(task="token-classification", model=model_name, dtype=torch.float16) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), @@ -900,9 +898,7 @@ def test_small_model_pt_fp16(self): @require_torch def test_small_model_pt_bf16(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" - token_classifier = pipeline( - task="token-classification", model=model_name, framework="pt", dtype=torch.bfloat16 - ) + token_classifier = pipeline(task="token-classification", model=model_name, dtype=torch.bfloat16) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), diff --git a/tests/pipelines/test_pipelines_translation.py b/tests/pipelines/test_pipelines_translation.py index 3b95cce1f70d..9f3225bee72e 100644 --- a/tests/pipelines/test_pipelines_translation.py +++ b/tests/pipelines/test_pipelines_translation.py @@ -79,7 +79,7 @@ def run_pipeline_test(self, translator, _): @require_torch def test_small_model_pt(self): - translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt") + translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random") outputs = translator("This is a test string", max_length=20) self.assertEqual( outputs, @@ -95,7 +95,7 @@ def test_small_model_pt(self): @require_torch def test_en_to_de_pt(self): - translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="pt") + translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random") outputs = translator("This is a test string", max_length=20) self.assertEqual( outputs, diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py index 479854dba972..9591936cedcd 100644 --- a/tests/pipelines/test_pipelines_zero_shot.py +++ b/tests/pipelines/test_pipelines_zero_shot.py @@ -165,7 +165,6 @@ def test_truncation(self): zero_shot_classifier = pipeline( "zero-shot-classification", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - framework="pt", ) # There was a regression in 4.10 for this # Adding a test so we don't make the mistake again. @@ -179,7 +178,6 @@ def test_small_model_pt(self): zero_shot_classifier = pipeline( "zero-shot-classification", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - framework="pt", ) outputs = zero_shot_classifier( "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] @@ -199,7 +197,6 @@ def test_small_model_pt_fp16(self): zero_shot_classifier = pipeline( "zero-shot-classification", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - framework="pt", dtype=torch.float16, ) outputs = zero_shot_classifier( @@ -220,7 +217,6 @@ def test_small_model_pt_bf16(self): zero_shot_classifier = pipeline( "zero-shot-classification", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", - framework="pt", dtype=torch.bfloat16, ) outputs = zero_shot_classifier( @@ -239,9 +235,7 @@ def test_small_model_pt_bf16(self): @slow @require_torch def test_large_model_pt(self): - zero_shot_classifier = pipeline( - "zero-shot-classification", model="FacebookAI/roberta-large-mnli", framework="pt" - ) + zero_shot_classifier = pipeline("zero-shot-classification", model="FacebookAI/roberta-large-mnli") outputs = zero_shot_classifier( "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] ) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 06ea41d9ec2c..a68395b0238d 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -473,7 +473,6 @@ def check_config_attributes_being_used(config_class): # Get the path to modeling source files config_source_file = inspect.getsourcefile(config_class) model_dir = os.path.dirname(config_source_file) - # Let's check against all frameworks: as long as one framework uses an attribute, we are good. modeling_paths = [os.path.join(model_dir, fn) for fn in os.listdir(model_dir) if fn.startswith("modeling_")] # Get the source code strings diff --git a/utils/create_dummy_models.py b/utils/create_dummy_models.py index 5e0239fb5c60..a561967fba10 100644 --- a/utils/create_dummy_models.py +++ b/utils/create_dummy_models.py @@ -62,7 +62,6 @@ raise ValueError("Please install PyTorch.") -FRAMEWORKS = ["pytorch"] INVALID_ARCH = [] TARGET_VOCAB_SIZE = 1024 @@ -773,11 +772,11 @@ def fill_result_with_error(result, error, trace, models_to_create): """Fill `result` with errors for all target model arch if we can't build processor""" error = (error, trace) result["error"] = error - for framework in FRAMEWORKS: - if framework in models_to_create: - result[framework] = {} - for model_arch in models_to_create[framework]: - result[framework][model_arch.__name__] = {"model": None, "checkpoint": None, "error": error} + + if "pytorch" in models_to_create: + result["pytorch"] = {} + for model_arch in models_to_create["pytorch"]: + result["pytorch"][model_arch.__name__] = {"model": None, "checkpoint": None, "error": error} result["processor"] = {p.__class__.__name__: p.__class__.__name__ for p in result["processor"].values()} @@ -1055,7 +1054,7 @@ def build(config_class, models_to_create, output_dir): of the same model type which is associated to `config_class`. output_dir (`str`): The directory to save all the checkpoints. Each model architecture will be saved in a subdirectory under - it. Models in different frameworks with the same architecture will be saved in the same subdirectory. + it. """ if data["training_ds"] is None or data["testing_ds"] is None: ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1") @@ -1208,40 +1207,40 @@ def build_tiny_model_summary(results, organization=None, token=None): processors = [key for key, value in results[config_name]["processor"].items()] tokenizer_classes = sorted([x for x in processors if x.endswith("TokenizerFast") or x.endswith("Tokenizer")]) processor_classes = sorted([x for x in processors if x not in tokenizer_classes]) - for framework in FRAMEWORKS: - if framework not in results[config_name]: - continue - for arch_name in results[config_name][framework]: - model_classes = [arch_name] - base_arch_name = arch_name - # tiny model is not created for `arch_name` - if results[config_name][framework][arch_name]["model"] is None: - model_classes = [] - if base_arch_name not in tiny_model_summary: - tiny_model_summary[base_arch_name] = {} - tiny_model_summary[base_arch_name].update( - { - "tokenizer_classes": tokenizer_classes, - "processor_classes": processor_classes, - } - ) - tiny_model_summary[base_arch_name]["model_classes"] = sorted( - tiny_model_summary[base_arch_name].get("model_classes", []) + model_classes - ) - if organization is not None: - repo_name = f"tiny-random-{base_arch_name}" - # composite models' checkpoints have more precise repo. names on the Hub. - if base_arch_name in COMPOSITE_MODELS: - repo_name = f"tiny-random-{COMPOSITE_MODELS[base_arch_name]}" - repo_id = f"{organization}/{repo_name}" - try: - commit_hash = hf_api.repo_info(repo_id, token=token).sha - except Exception: - # The directory is not created, but processor(s) is/are included in `results`. - logger.warning(f"Failed to get information for {repo_id}.\n{traceback.format_exc()}") - del tiny_model_summary[base_arch_name] - continue - tiny_model_summary[base_arch_name]["sha"] = commit_hash + + if "pytorch" not in results[config_name]: + continue + for arch_name in results[config_name]["pytorch"]: + model_classes = [arch_name] + base_arch_name = arch_name + # tiny model is not created for `arch_name` + if results[config_name]["pytorch"][arch_name]["model"] is None: + model_classes = [] + if base_arch_name not in tiny_model_summary: + tiny_model_summary[base_arch_name] = {} + tiny_model_summary[base_arch_name].update( + { + "tokenizer_classes": tokenizer_classes, + "processor_classes": processor_classes, + } + ) + tiny_model_summary[base_arch_name]["model_classes"] = sorted( + tiny_model_summary[base_arch_name].get("model_classes", []) + model_classes + ) + if organization is not None: + repo_name = f"tiny-random-{base_arch_name}" + # composite models' checkpoints have more precise repo. names on the Hub. + if base_arch_name in COMPOSITE_MODELS: + repo_name = f"tiny-random-{COMPOSITE_MODELS[base_arch_name]}" + repo_id = f"{organization}/{repo_name}" + try: + commit_hash = hf_api.repo_info(repo_id, token=token).sha + except Exception: + # The directory is not created, but processor(s) is/are included in `results`. + logger.warning(f"Failed to get information for {repo_id}.\n{traceback.format_exc()}") + del tiny_model_summary[base_arch_name] + continue + tiny_model_summary[base_arch_name]["sha"] = commit_hash return tiny_model_summary @@ -1259,19 +1258,18 @@ def build_failed_report(results, include_warning=True): failed_results[config_name] = {} failed_results[config_name]["warnings"] = results[config_name]["warnings"] - for framework in FRAMEWORKS: - if framework not in results[config_name]: - continue - for arch_name in results[config_name][framework]: - if "error" in results[config_name][framework][arch_name]: - if config_name not in failed_results: - failed_results[config_name] = {} - if framework not in failed_results[config_name]: - failed_results[config_name][framework] = {} - if arch_name not in failed_results[config_name][framework]: - failed_results[config_name][framework][arch_name] = {} - error = results[config_name][framework][arch_name]["error"] - failed_results[config_name][framework][arch_name]["error"] = error + if "pytorch" not in results[config_name]: + continue + for arch_name in results[config_name]["pytorch"]: + if "error" in results[config_name]["pytorch"][arch_name]: + if config_name not in failed_results: + failed_results[config_name] = {} + if "pytorch" not in failed_results[config_name]: + failed_results[config_name]["pytorch"] = {} + if arch_name not in failed_results[config_name]["pytorch"]: + failed_results[config_name]["pytorch"][arch_name] = {} + error = results[config_name]["pytorch"][arch_name]["error"] + failed_results[config_name]["pytorch"][arch_name]["error"] = error return failed_results @@ -1280,16 +1278,15 @@ def build_simple_report(results): text = "" failed_text = "" for config_name in results: - for framework in FRAMEWORKS: - if framework not in results[config_name]: - continue - for arch_name in results[config_name][framework]: - if "error" in results[config_name][framework][arch_name]: - result = results[config_name][framework][arch_name]["error"] - failed_text += f"{arch_name}: {result[0]}\n" - else: - result = ("OK",) - text += f"{arch_name}: {result[0]}\n" + if "pytorch" not in results[config_name]: + continue + for arch_name in results[config_name]["pytorch"]: + if "error" in results[config_name]["pytorch"][arch_name]: + result = results[config_name]["pytorch"][arch_name]["error"] + failed_text += f"{arch_name}: {result[0]}\n" + else: + result = ("OK",) + text += f"{arch_name}: {result[0]}\n" return text, failed_text diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 25bc2fde3663..942313d57bb2 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -319,7 +319,6 @@ src/transformers/commands/run.py src/transformers/commands/serving.py src/transformers/commands/transformers_cli.py src/transformers/configuration_utils.py -src/transformers/convert_graph_to_onnx.py src/transformers/convert_slow_tokenizer.py src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py src/transformers/data/data_collator.py diff --git a/utils/test_module/custom_pipeline.py b/utils/test_module/custom_pipeline.py index 4c7928b1ccd1..1fbb05ff804d 100644 --- a/utils/test_module/custom_pipeline.py +++ b/utils/test_module/custom_pipeline.py @@ -17,7 +17,7 @@ def _sanitize_parameters(self, **kwargs): return preprocess_kwargs, {}, {} def preprocess(self, text, second_text=None): - return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework) + return self.tokenizer(text, text_pair=second_text, return_tensors="pt") def _forward(self, model_inputs): return self.model(**model_inputs) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 0e80ca2b866d..e122b64ce849 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -211,27 +211,20 @@ def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> d Returns: `Dict[str, Tuple[str, str]]`: The updated table in the same format. """ - auto_modules = [ - transformers_module.models.auto.modeling_auto, - ] - for pipeline_tag, model_mapping, auto_class in PIPELINE_TAGS_AND_AUTO_MODELS: - model_mappings = [model_mapping] - auto_classes = [auto_class] - # Loop through all three frameworks - for module, cls, mapping in zip(auto_modules, auto_classes, model_mappings): - # The type of pipeline may not exist in this framework - if not hasattr(module, mapping): - continue - # First extract all model_names - model_names = [] - for name in getattr(module, mapping).values(): - if isinstance(name, str): - model_names.append(name) - else: - model_names.extend(list(name)) - - # Add pipeline tag and auto model class for those models - table.update(dict.fromkeys(model_names, (pipeline_tag, cls))) + module = transformers_module.models.auto.modeling_auto + for pipeline_tag, model_mapping, cls in PIPELINE_TAGS_AND_AUTO_MODELS: + if not hasattr(module, model_mapping): + continue + # First extract all model_names + model_names = [] + for name in getattr(module, model_mapping).values(): + if isinstance(name, str): + model_names.append(name) + else: + model_names.extend(list(name)) + + # Add pipeline tag and auto model class for those models + table.update(dict.fromkeys(model_names, (pipeline_tag, cls))) return table diff --git a/utils/update_tiny_models.py b/utils/update_tiny_models.py index 9dc4cdf6e6b2..c7d62f8e94b8 100644 --- a/utils/update_tiny_models.py +++ b/utils/update_tiny_models.py @@ -21,7 +21,6 @@ """ import argparse -import copy import json import multiprocessing import os @@ -62,23 +61,12 @@ def get_all_model_names(): def get_tiny_model_names_from_repo(): - # All model names defined in auto mappings - model_names = set(get_all_model_names()) - with open("tests/utils/tiny_model_summary.json") as fp: tiny_model_info = json.load(fp) tiny_models_names = set() for model_base_name in tiny_model_info: tiny_models_names.update(tiny_model_info[model_base_name]["model_classes"]) - # Remove a tiny model name if one of its framework implementation hasn't yet a tiny version on the Hub. - not_on_hub = model_names.difference(tiny_models_names) - for model_name in copy.copy(tiny_models_names): - if not model_name.startswith("TF") and f"TF{model_name}" in not_on_hub: - tiny_models_names.remove(model_name) - elif model_name.startswith("TF") and model_name[2:] in not_on_hub: - tiny_models_names.remove(model_name) - return sorted(tiny_models_names) From 14daddd6e8f6342c542fdba0731b0b4b74260343 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 10 Sep 2025 23:32:00 +0200 Subject: [PATCH 13/35] in tje zone --- .../image_processing_new_imgproc_model.py | 2 -- .../feature_extraction_audio_spectrogram_transformer.py | 1 - src/transformers/models/beit/image_processing_beit.py | 2 -- src/transformers/models/bit/image_processing_bit.py | 2 -- src/transformers/models/blip/image_processing_blip.py | 2 -- .../models/bridgetower/image_processing_bridgetower.py | 4 ---- .../models/chameleon/image_processing_chameleon.py | 2 -- .../models/chinese_clip/image_processing_chinese_clip.py | 2 -- src/transformers/models/clap/feature_extraction_clap.py | 1 - src/transformers/models/clip/image_processing_clip.py | 2 -- src/transformers/models/clvp/feature_extraction_clvp.py | 1 - src/transformers/models/convnext/image_processing_convnext.py | 2 -- src/transformers/models/dac/feature_extraction_dac.py | 1 - .../models/deepseek_vl/image_processing_deepseek_vl.py | 2 -- .../deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py | 2 -- .../models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py | 2 -- src/transformers/models/deit/image_processing_deit.py | 2 -- .../efficientformer/image_processing_efficientformer.py | 2 -- .../models/deprecated/mctct/feature_extraction_mctct.py | 1 - .../models/deprecated/tapex/tokenization_tapex.py | 1 - .../models/deprecated/tvlt/image_processing_tvlt.py | 2 -- .../deprecated/vit_hybrid/image_processing_vit_hybrid.py | 2 -- .../models/depth_pro/image_processing_depth_pro.py | 2 -- src/transformers/models/dia/feature_extraction_dia.py | 1 - src/transformers/models/donut/image_processing_donut.py | 2 -- src/transformers/models/dpr/tokenization_dpr.py | 1 - src/transformers/models/dpr/tokenization_dpr_fast.py | 1 - src/transformers/models/dpt/image_processing_dpt.py | 2 -- .../models/efficientloftr/image_processing_efficientloftr.py | 2 -- .../models/efficientnet/image_processing_efficientnet.py | 2 -- src/transformers/models/emu3/image_processing_emu3.py | 2 -- src/transformers/models/encodec/feature_extraction_encodec.py | 1 - src/transformers/models/flava/image_processing_flava.py | 2 -- src/transformers/models/fuyu/image_processing_fuyu.py | 2 -- src/transformers/models/gemma3/image_processing_gemma3.py | 2 -- src/transformers/models/glm4v/image_processing_glm4v.py | 2 -- src/transformers/models/glpn/image_processing_glpn.py | 2 -- src/transformers/models/got_ocr2/image_processing_got_ocr2.py | 2 -- src/transformers/models/idefics2/image_processing_idefics2.py | 4 ---- src/transformers/models/idefics3/image_processing_idefics3.py | 4 ---- src/transformers/models/imagegpt/image_processing_imagegpt.py | 2 -- .../instructblipvideo/image_processing_instructblipvideo.py | 2 -- src/transformers/models/janus/image_processing_janus.py | 2 -- .../models/kosmos2_5/image_processing_kosmos2_5.py | 2 -- .../feature_extraction_kyutai_speech_to_text.py | 1 - .../kyutai_speech_to_text/modular_kyutai_speech_to_text.py | 1 - .../models/layoutlmv2/image_processing_layoutlmv2.py | 2 -- src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py | 1 - .../models/layoutlmv3/image_processing_layoutlmv3.py | 2 -- src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py | 2 -- src/transformers/models/layoutxlm/tokenization_layoutxlm.py | 1 - .../models/layoutxlm/tokenization_layoutxlm_fast.py | 1 - src/transformers/models/levit/image_processing_levit.py | 2 -- .../models/lightglue/image_processing_lightglue.py | 2 -- src/transformers/models/llava/image_processing_llava.py | 2 -- .../models/llava_next/image_processing_llava_next.py | 2 -- .../llava_next_video/image_processing_llava_next_video.py | 2 -- .../llava_onevision/image_processing_llava_onevision.py | 2 -- src/transformers/models/markuplm/tokenization_markuplm.py | 1 - .../models/mask2former/image_processing_mask2former.py | 2 -- .../models/maskformer/image_processing_maskformer.py | 2 -- src/transformers/models/mllama/image_processing_mllama.py | 2 -- .../models/mobilenet_v1/image_processing_mobilenet_v1.py | 2 -- .../models/mobilenet_v2/image_processing_mobilenet_v2.py | 2 -- .../models/mobilevit/image_processing_mobilevit.py | 2 -- .../musicgen_melody/feature_extraction_musicgen_melody.py | 1 - src/transformers/models/nougat/image_processing_nougat.py | 2 -- .../models/oneformer/image_processing_oneformer.py | 2 -- src/transformers/models/ovis2/image_processing_ovis2.py | 2 -- src/transformers/models/owlv2/image_processing_owlv2.py | 2 -- src/transformers/models/owlvit/image_processing_owlvit.py | 2 -- .../models/perceiver/image_processing_perceiver.py | 2 -- .../phi4_multimodal/feature_extraction_phi4_multimodal.py | 1 - .../models/pix2struct/image_processing_pix2struct.py | 2 -- src/transformers/models/pixtral/image_processing_pixtral.py | 2 -- .../models/poolformer/image_processing_poolformer.py | 2 -- src/transformers/models/pop2piano/tokenization_pop2piano.py | 1 - .../image_processing_prompt_depth_anything.py | 2 -- src/transformers/models/pvt/image_processing_pvt.py | 2 -- src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py | 2 -- src/transformers/models/rag/retrieval_rag.py | 1 - .../models/seamless_m4t/feature_extraction_seamless_m4t.py | 1 - .../models/segformer/image_processing_segformer.py | 2 -- src/transformers/models/seggpt/image_processing_seggpt.py | 4 ---- src/transformers/models/siglip/image_processing_siglip.py | 2 -- src/transformers/models/siglip2/image_processing_siglip2.py | 2 -- src/transformers/models/smolvlm/image_processing_smolvlm.py | 4 ---- .../speech_to_text/feature_extraction_speech_to_text.py | 1 - .../models/speecht5/feature_extraction_speecht5.py | 1 - .../models/superglue/image_processing_superglue.py | 2 -- .../models/superpoint/image_processing_superpoint.py | 2 -- src/transformers/models/swin2sr/image_processing_swin2sr.py | 3 --- src/transformers/models/tapas/tokenization_tapas.py | 1 - src/transformers/models/textnet/image_processing_textnet.py | 2 -- src/transformers/models/tvp/image_processing_tvp.py | 2 -- src/transformers/models/udop/tokenization_udop.py | 1 - src/transformers/models/udop/tokenization_udop_fast.py | 1 - src/transformers/models/univnet/feature_extraction_univnet.py | 1 - .../models/video_llava/image_processing_video_llava.py | 2 -- src/transformers/models/videomae/image_processing_videomae.py | 2 -- src/transformers/models/vilt/image_processing_vilt.py | 4 ---- src/transformers/models/vit/image_processing_vit.py | 2 -- src/transformers/models/vitmatte/image_processing_vitmatte.py | 2 -- src/transformers/models/vivit/image_processing_vivit.py | 2 -- .../models/wav2vec2/feature_extraction_wav2vec2.py | 1 - src/transformers/models/whisper/feature_extraction_whisper.py | 1 - src/transformers/models/zoedepth/image_processing_zoedepth.py | 2 -- 107 files changed, 198 deletions(-) diff --git a/examples/modular-transformers/image_processing_new_imgproc_model.py b/examples/modular-transformers/image_processing_new_imgproc_model.py index d795dc5b32ab..7dae62f883f8 100644 --- a/examples/modular-transformers/image_processing_new_imgproc_model.py +++ b/examples/modular-transformers/image_processing_new_imgproc_model.py @@ -194,10 +194,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py index f56c0c3213b7..b7ff6fa08e2f 100644 --- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py @@ -179,7 +179,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index a93051b00b25..0029480e46d6 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -360,10 +360,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index 3d32752edca8..2c24f3f1b969 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -226,10 +226,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py index 78a152374fd0..ca04b75583b0 100644 --- a/src/transformers/models/blip/image_processing_blip.py +++ b/src/transformers/models/blip/image_processing_blip.py @@ -204,10 +204,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index cb39ed097561..b681292ae6d3 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -339,10 +339,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -431,10 +429,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index 9cae9d7bdd34..4e9b2e1c1755 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -218,10 +218,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index c55805f28913..3bb47ffe97b3 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -219,10 +219,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py index e333248c18ed..75c79c4e3834 100644 --- a/src/transformers/models/clap/feature_extraction_clap.py +++ b/src/transformers/models/clap/feature_extraction_clap.py @@ -290,7 +290,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.np.array` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index ea17e4a65ff4..08b27680c8e9 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -253,10 +253,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/clvp/feature_extraction_clvp.py b/src/transformers/models/clvp/feature_extraction_clvp.py index 160666ef78c0..077e70af67b1 100644 --- a/src/transformers/models/clvp/feature_extraction_clvp.py +++ b/src/transformers/models/clvp/feature_extraction_clvp.py @@ -170,7 +170,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. padding_value (`float`, *optional*, defaults to 0.0): diff --git a/src/transformers/models/convnext/image_processing_convnext.py b/src/transformers/models/convnext/image_processing_convnext.py index af89274500dd..0ac9a8af06e0 100644 --- a/src/transformers/models/convnext/image_processing_convnext.py +++ b/src/transformers/models/convnext/image_processing_convnext.py @@ -234,10 +234,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/dac/feature_extraction_dac.py b/src/transformers/models/dac/feature_extraction_dac.py index e81a2466dd91..21af67e2233a 100644 --- a/src/transformers/models/dac/feature_extraction_dac.py +++ b/src/transformers/models/dac/feature_extraction_dac.py @@ -92,7 +92,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py index 7ab4e98012ac..dca5eeb296ee 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py @@ -250,10 +250,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py index 7c7d6df82424..0ac9602eb72c 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py @@ -288,10 +288,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 5149f7f42178..3cc62d7449e5 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -600,10 +600,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deit/image_processing_deit.py b/src/transformers/models/deit/image_processing_deit.py index 1e2f6c3b5ae5..dbb83412c563 100644 --- a/src/transformers/models/deit/image_processing_deit.py +++ b/src/transformers/models/deit/image_processing_deit.py @@ -211,10 +211,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - `None`: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py b/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py index a2dd1281e920..ec9e4f6f3695 100644 --- a/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py +++ b/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py @@ -225,10 +225,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deprecated/mctct/feature_extraction_mctct.py b/src/transformers/models/deprecated/mctct/feature_extraction_mctct.py index 966a160f91b0..0ce7c1da31a2 100644 --- a/src/transformers/models/deprecated/mctct/feature_extraction_mctct.py +++ b/src/transformers/models/deprecated/mctct/feature_extraction_mctct.py @@ -207,7 +207,6 @@ def __call__( return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/deprecated/tapex/tokenization_tapex.py b/src/transformers/models/deprecated/tapex/tokenization_tapex.py index b32383ddd497..fa74d8aa3b55 100644 --- a/src/transformers/models/deprecated/tapex/tokenization_tapex.py +++ b/src/transformers/models/deprecated/tapex/tokenization_tapex.py @@ -93,7 +93,6 @@ class TapexTruncationStrategy(ExplicitEnum): return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py b/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py index c0e1a33f091b..224a35eb0e79 100644 --- a/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py @@ -332,10 +332,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py index 92d518363b2c..27b1f83a7ba8 100644 --- a/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py @@ -243,10 +243,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro.py b/src/transformers/models/depth_pro/image_processing_depth_pro.py index 47f224f248bd..2b81c95db818 100644 --- a/src/transformers/models/depth_pro/image_processing_depth_pro.py +++ b/src/transformers/models/depth_pro/image_processing_depth_pro.py @@ -232,10 +232,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/dia/feature_extraction_dia.py b/src/transformers/models/dia/feature_extraction_dia.py index b4376b773b27..dcb32d2be6f4 100644 --- a/src/transformers/models/dia/feature_extraction_dia.py +++ b/src/transformers/models/dia/feature_extraction_dia.py @@ -92,7 +92,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 7dec96422c5d..75bcb9cda994 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -363,10 +363,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/dpr/tokenization_dpr.py b/src/transformers/models/dpr/tokenization_dpr.py index 020b235cb6bd..1a87ef9fd915 100644 --- a/src/transformers/models/dpr/tokenization_dpr.py +++ b/src/transformers/models/dpr/tokenization_dpr.py @@ -112,7 +112,6 @@ class DPRQuestionEncoderTokenizer(BertTokenizer): return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_attention_mask (`bool`, *optional*): diff --git a/src/transformers/models/dpr/tokenization_dpr_fast.py b/src/transformers/models/dpr/tokenization_dpr_fast.py index dbf745291745..5f501dbdd4f0 100644 --- a/src/transformers/models/dpr/tokenization_dpr_fast.py +++ b/src/transformers/models/dpr/tokenization_dpr_fast.py @@ -113,7 +113,6 @@ class DPRQuestionEncoderTokenizerFast(BertTokenizerFast): return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_attention_mask (`bool`, *optional*): diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 189b34e9fb7b..70440017e518 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -495,10 +495,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py index 2146ea8b39eb..45f28220ae84 100644 --- a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py +++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py @@ -259,10 +259,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet.py b/src/transformers/models/efficientnet/image_processing_efficientnet.py index ea822d75ca27..5331f4a4dea0 100644 --- a/src/transformers/models/efficientnet/image_processing_efficientnet.py +++ b/src/transformers/models/efficientnet/image_processing_efficientnet.py @@ -264,10 +264,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - `None`: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/emu3/image_processing_emu3.py b/src/transformers/models/emu3/image_processing_emu3.py index aaf3afa41733..be4decd410dc 100644 --- a/src/transformers/models/emu3/image_processing_emu3.py +++ b/src/transformers/models/emu3/image_processing_emu3.py @@ -329,10 +329,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/encodec/feature_extraction_encodec.py b/src/transformers/models/encodec/feature_extraction_encodec.py index 3cc8d523f7f0..1086bdfb228e 100644 --- a/src/transformers/models/encodec/feature_extraction_encodec.py +++ b/src/transformers/models/encodec/feature_extraction_encodec.py @@ -116,7 +116,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py index 7b4db246a8fa..1e3fecfd9740 100644 --- a/src/transformers/models/flava/image_processing_flava.py +++ b/src/transformers/models/flava/image_processing_flava.py @@ -564,10 +564,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index e52d9dc8ee91..a1aa184a3b3c 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -414,10 +414,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format of the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/gemma3/image_processing_gemma3.py b/src/transformers/models/gemma3/image_processing_gemma3.py index 8addbbfd378c..02db120c8414 100644 --- a/src/transformers/models/gemma3/image_processing_gemma3.py +++ b/src/transformers/models/gemma3/image_processing_gemma3.py @@ -285,10 +285,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/glm4v/image_processing_glm4v.py b/src/transformers/models/glm4v/image_processing_glm4v.py index 8293545deee2..ad6549826fe2 100644 --- a/src/transformers/models/glm4v/image_processing_glm4v.py +++ b/src/transformers/models/glm4v/image_processing_glm4v.py @@ -352,10 +352,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index e3e0255e2b47..2e2d82a23322 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -166,10 +166,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - `None`: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py index 209ac88ea2fb..6880cfe208a0 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -309,10 +309,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/idefics2/image_processing_idefics2.py b/src/transformers/models/idefics2/image_processing_idefics2.py index 3f0db7644563..22befb6ceaed 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2.py +++ b/src/transformers/models/idefics2/image_processing_idefics2.py @@ -303,10 +303,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -444,10 +442,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index e460a041965a..e298f0890c99 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -546,10 +546,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -657,10 +655,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. return_row_col_info (`bool`, *optional*, default to `False`): Whether to return the number of rows and columns of the split images. This is used for the `Idefics3Processor` to generate prompt strings based on the number of rows and columns. diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py index 9168ecaceff2..42f3a296103b 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py @@ -213,10 +213,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py index 56391b59dbdd..3e1ee362aad8 100644 --- a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py @@ -207,10 +207,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/janus/image_processing_janus.py b/src/transformers/models/janus/image_processing_janus.py index 16659bd85354..b995cdcf8b92 100644 --- a/src/transformers/models/janus/image_processing_janus.py +++ b/src/transformers/models/janus/image_processing_janus.py @@ -247,10 +247,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py index 18c087a4b368..a16a1eba626a 100644 --- a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py @@ -267,10 +267,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py index bde1736f9da8..fa0ce5e11ded 100644 --- a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py @@ -126,7 +126,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py index 8541a911e947..16e8f6cd6dcb 100644 --- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -105,7 +105,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py index de2e7361a6d3..c959583ae5bc 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -234,10 +234,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py index 7d82b5cf4104..a4c04598d855 100644 --- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py @@ -81,7 +81,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py index 8189abf67311..855ae52e8075 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py @@ -279,10 +279,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py index b69fc57b1743..fdf95a34d58d 100644 --- a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py @@ -85,7 +85,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ @@ -134,7 +133,6 @@ return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py index 3dbe3c21a636..9c1d5c05a9f9 100644 --- a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py +++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py @@ -84,7 +84,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_token_type_ids (`bool`, *optional*): diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py index 6710c6c8cb66..7b08a3aa5f0e 100644 --- a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py +++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py @@ -86,7 +86,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_token_type_ids (`bool`, *optional*): diff --git a/src/transformers/models/levit/image_processing_levit.py b/src/transformers/models/levit/image_processing_levit.py index 5bf03b39e4b9..09c23d3bad91 100644 --- a/src/transformers/models/levit/image_processing_levit.py +++ b/src/transformers/models/levit/image_processing_levit.py @@ -226,10 +226,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py index ce925ea173dd..4263c3712407 100644 --- a/src/transformers/models/lightglue/image_processing_lightglue.py +++ b/src/transformers/models/lightglue/image_processing_lightglue.py @@ -260,10 +260,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/llava/image_processing_llava.py b/src/transformers/models/llava/image_processing_llava.py index d3aa81303bb8..a77e7649b7d0 100644 --- a/src/transformers/models/llava/image_processing_llava.py +++ b/src/transformers/models/llava/image_processing_llava.py @@ -334,10 +334,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 3887c9c7ad4b..7b8bc4a513ef 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -606,10 +606,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py index ba1cd30a1133..8468c20afa4e 100644 --- a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py @@ -326,10 +326,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index 837eda460802..65098512366b 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -650,10 +650,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/markuplm/tokenization_markuplm.py b/src/transformers/models/markuplm/tokenization_markuplm.py index a090e11ec36d..0a6f7c3bd6a0 100644 --- a/src/transformers/models/markuplm/tokenization_markuplm.py +++ b/src/transformers/models/markuplm/tokenization_markuplm.py @@ -83,7 +83,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py index a0c369722b54..bebab8b9e2da 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former.py +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -858,10 +858,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 9ce33846170e..f94202f47243 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -860,10 +860,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): diff --git a/src/transformers/models/mllama/image_processing_mllama.py b/src/transformers/models/mllama/image_processing_mllama.py index ba1a596aa459..a331e6d5319d 100644 --- a/src/transformers/models/mllama/image_processing_mllama.py +++ b/src/transformers/models/mllama/image_processing_mllama.py @@ -655,10 +655,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. Returns: `BatchFeature` of the following structure: diff --git a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py index 6fa3f443c53b..897a5f2074d9 100644 --- a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -216,10 +216,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index 665173d59b95..3224c2665704 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -377,10 +377,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index bbf76ff04023..93ff9bc6a1c2 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -373,10 +373,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py index ec23899e91e9..744471bab553 100644 --- a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py @@ -211,7 +211,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_attention_mask (`bool`, *optional*): diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py index 0c0a51464b43..793f450484fb 100644 --- a/src/transformers/models/nougat/image_processing_nougat.py +++ b/src/transformers/models/nougat/image_processing_nougat.py @@ -419,10 +419,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 615c71593062..7311e017e213 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -815,10 +815,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): diff --git a/src/transformers/models/ovis2/image_processing_ovis2.py b/src/transformers/models/ovis2/image_processing_ovis2.py index bd6d63e83914..ce776d24c7db 100644 --- a/src/transformers/models/ovis2/image_processing_ovis2.py +++ b/src/transformers/models/ovis2/image_processing_ovis2.py @@ -367,10 +367,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py index 64399d433f5e..19a04cdeb871 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2.py +++ b/src/transformers/models/owlv2/image_processing_owlv2.py @@ -407,10 +407,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index 58253073d5d8..0af755297afd 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -355,10 +355,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/perceiver/image_processing_perceiver.py b/src/transformers/models/perceiver/image_processing_perceiver.py index c66d7b51d463..acc72f0ab877 100644 --- a/src/transformers/models/perceiver/image_processing_perceiver.py +++ b/src/transformers/models/perceiver/image_processing_perceiver.py @@ -258,10 +258,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py index 71ada3a8c62a..16b5875eb001 100644 --- a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py @@ -110,7 +110,6 @@ def __call__( If set, will return tensors instead of numpy arrays. Acceptable values are: - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. - - `'tf'`: Return TensorFlow `tf.constant` objects. return_attention_mask (`bool`, *optional*, defaults to `True`): Whether to return the extracted audio input features' attention mask. device (`str`, *optional*, defaults to "cpu"): diff --git a/src/transformers/models/pix2struct/image_processing_pix2struct.py b/src/transformers/models/pix2struct/image_processing_pix2struct.py index f1cc2ba2068b..610ac36ea086 100644 --- a/src/transformers/models/pix2struct/image_processing_pix2struct.py +++ b/src/transformers/models/pix2struct/image_processing_pix2struct.py @@ -376,10 +376,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index c6c6fdb163ab..4c08f533e9e2 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -366,10 +366,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/poolformer/image_processing_poolformer.py b/src/transformers/models/poolformer/image_processing_poolformer.py index ee5500c823cc..b8ea103a5366 100644 --- a/src/transformers/models/poolformer/image_processing_poolformer.py +++ b/src/transformers/models/poolformer/image_processing_poolformer.py @@ -261,10 +261,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/pop2piano/tokenization_pop2piano.py b/src/transformers/models/pop2piano/tokenization_pop2piano.py index f7aea3479f6f..c81165a03be4 100644 --- a/src/transformers/models/pop2piano/tokenization_pop2piano.py +++ b/src/transformers/models/pop2piano/tokenization_pop2piano.py @@ -542,7 +542,6 @@ def __call__( return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. verbose (`bool`, *optional*, defaults to `True`): diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 4f0f68240f9a..985fe290f689 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -336,10 +336,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/pvt/image_processing_pvt.py b/src/transformers/models/pvt/image_processing_pvt.py index 9f687fe7548f..c86727fff290 100644 --- a/src/transformers/models/pvt/image_processing_pvt.py +++ b/src/transformers/models/pvt/image_processing_pvt.py @@ -189,10 +189,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py index 552b289f58c0..70af072d1bee 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py @@ -360,10 +360,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py index e397d111a0a4..6fb924c8b7bc 100644 --- a/src/transformers/models/rag/retrieval_rag.py +++ b/src/transformers/models/rag/retrieval_rag.py @@ -610,7 +610,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to "pt"): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py index 59a20059f869..d0a8a07b3a9e 100644 --- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -204,7 +204,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index 4025b59e7ebb..15c1b899e7dd 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ b/src/transformers/models/segformer/image_processing_segformer.py @@ -343,10 +343,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/seggpt/image_processing_seggpt.py b/src/transformers/models/seggpt/image_processing_seggpt.py index ffadfaf85edb..8ced740d3825 100644 --- a/src/transformers/models/seggpt/image_processing_seggpt.py +++ b/src/transformers/models/seggpt/image_processing_seggpt.py @@ -285,10 +285,8 @@ def _preprocess_step( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. @@ -453,10 +451,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/siglip/image_processing_siglip.py b/src/transformers/models/siglip/image_processing_siglip.py index 0ffed5258de5..60ee68e1eba2 100644 --- a/src/transformers/models/siglip/image_processing_siglip.py +++ b/src/transformers/models/siglip/image_processing_siglip.py @@ -152,10 +152,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/siglip2/image_processing_siglip2.py b/src/transformers/models/siglip2/image_processing_siglip2.py index 30b5f1b958af..ef8685b1c312 100644 --- a/src/transformers/models/siglip2/image_processing_siglip2.py +++ b/src/transformers/models/siglip2/image_processing_siglip2.py @@ -235,10 +235,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: diff --git a/src/transformers/models/smolvlm/image_processing_smolvlm.py b/src/transformers/models/smolvlm/image_processing_smolvlm.py index c08339b81732..36f42462540b 100644 --- a/src/transformers/models/smolvlm/image_processing_smolvlm.py +++ b/src/transformers/models/smolvlm/image_processing_smolvlm.py @@ -543,10 +543,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -654,10 +652,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. return_row_col_info (`bool`, *optional*, default to `False`): Whether to return the number of rows and columns of the split images. This is used for the `SmolVLMProcessor` to generate prompt strings based on the number of rows and columns. diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py index 64627033671e..fe6698e9ebec 100644 --- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -229,7 +229,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py index 822ae01a88d7..3c30eed90bd1 100644 --- a/src/transformers/models/speecht5/feature_extraction_speecht5.py +++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py @@ -233,7 +233,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py index f9192ac82df1..b6c79c814b97 100644 --- a/src/transformers/models/superglue/image_processing_superglue.py +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -261,10 +261,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index f5da3335078e..47f2e2e4b473 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -218,10 +218,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr.py b/src/transformers/models/swin2sr/image_processing_swin2sr.py index 76c5e907da1c..def1c6d533ff 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -154,11 +154,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of typ, input_data_format=input_data_format - `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py index 4d91a1add944..44558f5542e4 100644 --- a/src/transformers/models/tapas/tokenization_tapas.py +++ b/src/transformers/models/tapas/tokenization_tapas.py @@ -142,7 +142,6 @@ def whitespace_tokenize(text): return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/textnet/image_processing_textnet.py b/src/transformers/models/textnet/image_processing_textnet.py index 153e29785289..7a461d7145c1 100644 --- a/src/transformers/models/textnet/image_processing_textnet.py +++ b/src/transformers/models/textnet/image_processing_textnet.py @@ -257,10 +257,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index d3f698873d55..9ef3c93fb13f 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -399,10 +399,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/udop/tokenization_udop.py b/src/transformers/models/udop/tokenization_udop.py index 29b4b3ee24e4..a5833333e10a 100644 --- a/src/transformers/models/udop/tokenization_udop.py +++ b/src/transformers/models/udop/tokenization_udop.py @@ -86,7 +86,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_token_type_ids (`bool`, *optional*): diff --git a/src/transformers/models/udop/tokenization_udop_fast.py b/src/transformers/models/udop/tokenization_udop_fast.py index a8878b9b514c..9751f5d65ddf 100644 --- a/src/transformers/models/udop/tokenization_udop_fast.py +++ b/src/transformers/models/udop/tokenization_udop_fast.py @@ -85,7 +85,6 @@ return_tensors (`str` or [`~file_utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. return_token_type_ids (`bool`, *optional*): diff --git a/src/transformers/models/univnet/feature_extraction_univnet.py b/src/transformers/models/univnet/feature_extraction_univnet.py index 059226afe1d1..6ff2b73df7a6 100644 --- a/src/transformers/models/univnet/feature_extraction_univnet.py +++ b/src/transformers/models/univnet/feature_extraction_univnet.py @@ -355,7 +355,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.np.array` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 1ed8f911af8e..0e9141a70dfb 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -226,10 +226,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py index 96545dc75311..622ae7ab9c47 100644 --- a/src/transformers/models/videomae/image_processing_videomae.py +++ b/src/transformers/models/videomae/image_processing_videomae.py @@ -283,10 +283,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index 87abf7f7a7d6..28f33cf9017a 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -305,10 +305,8 @@ def pad( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -389,10 +387,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/vit/image_processing_vit.py b/src/transformers/models/vit/image_processing_vit.py index 16216e2eac90..b7e118007ace 100644 --- a/src/transformers/models/vit/image_processing_vit.py +++ b/src/transformers/models/vit/image_processing_vit.py @@ -195,10 +195,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py index 6e65a634d23d..6a06df90a30f 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -188,10 +188,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index ab32d5b47eef..51655a71cf51 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -338,10 +338,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 6bca69e82d09..3b830c314b31 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -160,7 +160,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index bf548ac8408f..e11895191f95 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -236,7 +236,6 @@ def __call__( return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. sampling_rate (`int`, *optional*): diff --git a/src/transformers/models/zoedepth/image_processing_zoedepth.py b/src/transformers/models/zoedepth/image_processing_zoedepth.py index 973b5279822c..85da72325232 100644 --- a/src/transformers/models/zoedepth/image_processing_zoedepth.py +++ b/src/transformers/models/zoedepth/image_processing_zoedepth.py @@ -357,10 +357,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. From e970b64bd3431d2e62d73727614dad5da595396e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 00:20:49 +0200 Subject: [PATCH 14/35] never stop --- .../image_processing_new_imgproc_model.py | 5 +- setup.py | 5 +- src/transformers/__init__.py | 4 - src/transformers/dependency_versions_table.py | 5 +- src/transformers/file_utils.py | 3 - src/transformers/modeling_utils.py | 3 +- src/transformers/models/albert/__init__.py | 2 - .../models/aria/image_processing_aria.py | 5 +- src/transformers/models/aria/modular_aria.py | 5 +- src/transformers/models/auto/__init__.py | 2 - src/transformers/models/bart/__init__.py | 2 - src/transformers/models/beit/__init__.py | 1 - .../models/beit/image_processing_beit.py | 8 +- src/transformers/models/bert/__init__.py | 2 - src/transformers/models/big_bird/__init__.py | 1 - .../models/bit/image_processing_bit.py | 5 +- .../models/blenderbot/__init__.py | 2 - .../models/blenderbot_small/__init__.py | 2 - src/transformers/models/blip/__init__.py | 2 - .../models/blip/image_processing_blip.py | 5 +- src/transformers/models/bloom/__init__.py | 1 - .../image_processing_bridgetower.py | 5 +- src/transformers/models/camembert/__init__.py | 1 - .../chameleon/image_processing_chameleon.py | 5 +- .../image_processing_chinese_clip.py | 5 +- src/transformers/models/clip/__init__.py | 2 - .../models/clip/image_processing_clip.py | 5 +- .../models/codegen/modeling_codegen.py | 2 - src/transformers/models/convbert/__init__.py | 1 - src/transformers/models/convnext/__init__.py | 1 - .../convnext/image_processing_convnext.py | 5 +- .../models/convnextv2/__init__.py | 1 - src/transformers/models/ctrl/__init__.py | 1 - src/transformers/models/cvt/__init__.py | 1 - src/transformers/models/data2vec/__init__.py | 1 - src/transformers/models/deberta/__init__.py | 1 - .../models/deberta_v2/__init__.py | 1 - .../image_processing_deepseek_vl.py | 5 +- .../image_processing_deepseek_vl_hybrid.py | 5 +- .../modular_deepseek_vl_hybrid.py | 5 +- src/transformers/models/deit/__init__.py | 1 - .../models/deit/image_processing_deit.py | 5 +- .../deprecated/efficientformer/__init__.py | 1 - .../image_processing_efficientformer.py | 5 +- .../models/deprecated/mega/modeling_mega.py | 2 - .../models/deprecated/transfo_xl/__init__.py | 1 - .../deprecated/tvlt/image_processing_tvlt.py | 3 +- .../vit_hybrid/image_processing_vit_hybrid.py | 5 +- .../depth_pro/image_processing_depth_pro.py | 5 +- src/transformers/models/dinov2/__init__.py | 1 - .../models/distilbert/__init__.py | 2 - .../models/donut/image_processing_donut.py | 5 +- src/transformers/models/dpr/__init__.py | 1 - .../models/dpt/image_processing_dpt.py | 5 +- .../image_processing_efficientloftr.py | 5 +- .../image_processing_efficientnet.py | 5 +- src/transformers/models/electra/__init__.py | 2 - .../models/emu3/image_processing_emu3.py | 5 +- .../models/encoder_decoder/__init__.py | 2 - .../models/eomt/image_processing_eomt.py | 5 +- src/transformers/models/esm/__init__.py | 1 - .../models/esm/openfold_utils/tensor_utils.py | 1 - src/transformers/models/flaubert/__init__.py | 1 - .../models/flava/image_processing_flava.py | 5 +- src/transformers/models/funnel/__init__.py | 1 - src/transformers/models/gemma/__init__.py | 1 - .../models/gemma3/image_processing_gemma3.py | 5 +- .../gemma3n/feature_extraction_gemma3n.py | 2 +- .../models/glm4v/image_processing_glm4v.py | 5 +- .../models/glpn/image_processing_glpn.py | 5 +- .../got_ocr2/image_processing_got_ocr2.py | 5 +- src/transformers/models/gpt2/__init__.py | 2 - src/transformers/models/gpt_neo/__init__.py | 1 - src/transformers/models/gptj/__init__.py | 2 - src/transformers/models/gptj/modeling_gptj.py | 2 - src/transformers/models/groupvit/__init__.py | 1 - src/transformers/models/hubert/__init__.py | 1 - src/transformers/models/idefics/__init__.py | 1 - .../idefics/image_processing_idefics.py | 5 +- .../idefics2/image_processing_idefics2.py | 5 +- .../idefics3/image_processing_idefics3.py | 5 +- .../imagegpt/image_processing_imagegpt.py | 5 +- .../image_processing_instructblipvideo.py | 5 +- .../models/janus/image_processing_janus.py | 5 +- .../models/jetmoe/modeling_jetmoe.py | 2 - .../kosmos2_5/image_processing_kosmos2_5.py | 5 +- src/transformers/models/layoutlm/__init__.py | 1 - .../layoutlmv2/image_processing_layoutlmv2.py | 5 +- .../models/layoutlmv3/__init__.py | 1 - .../layoutlmv3/image_processing_layoutlmv3.py | 5 +- src/transformers/models/led/__init__.py | 1 - .../models/levit/image_processing_levit.py | 5 +- .../lightglue/image_processing_lightglue.py | 5 +- src/transformers/models/llama/__init__.py | 1 - .../models/llava/image_processing_llava.py | 5 +- .../llava_next/image_processing_llava_next.py | 5 +- .../image_processing_llava_onevision.py | 5 +- .../models/longformer/__init__.py | 1 - src/transformers/models/longt5/__init__.py | 1 - .../models/longt5/configuration_longt5.py | 2 +- .../convert_longt5x_checkpoint_to_flax.py | 215 ---------------- src/transformers/models/lxmert/__init__.py | 1 - src/transformers/models/marian/__init__.py | 2 - .../image_processing_mask2former.py | 8 +- .../maskformer/image_processing_maskformer.py | 8 +- src/transformers/models/mbart/__init__.py | 2 - src/transformers/models/mistral/__init__.py | 2 - .../models/mobilebert/__init__.py | 1 - .../image_processing_mobilenet_v1.py | 5 +- .../image_processing_mobilenet_v2.py | 8 +- src/transformers/models/mobilevit/__init__.py | 1 - .../mobilevit/image_processing_mobilevit.py | 8 +- src/transformers/models/mpnet/__init__.py | 1 - src/transformers/models/mt5/__init__.py | 2 - .../models/nougat/image_processing_nougat.py | 5 +- .../oneformer/image_processing_oneformer.py | 8 +- src/transformers/models/openai/__init__.py | 1 - src/transformers/models/opt/__init__.py | 2 - .../models/ovis2/image_processing_ovis2.py | 5 +- .../models/owlv2/image_processing_owlv2.py | 5 +- .../models/owlvit/image_processing_owlvit.py | 5 +- src/transformers/models/pegasus/__init__.py | 2 - .../perceiver/image_processing_perceiver.py | 5 +- .../pix2struct/image_processing_pix2struct.py | 5 +- .../pixtral/image_processing_pixtral.py | 5 +- .../poolformer/image_processing_poolformer.py | 5 +- .../image_processing_prompt_depth_anything.py | 5 +- .../models/pvt/image_processing_pvt.py | 5 +- .../qwen2_vl/image_processing_qwen2_vl.py | 5 +- src/transformers/models/rag/__init__.py | 1 - src/transformers/models/regnet/__init__.py | 2 - src/transformers/models/rembert/__init__.py | 1 - src/transformers/models/resnet/__init__.py | 2 - src/transformers/models/roberta/__init__.py | 2 - .../models/roberta_prelayernorm/__init__.py | 2 - src/transformers/models/roformer/__init__.py | 2 - .../rt_detr/image_processing_rt_detr.py | 5 +- src/transformers/models/sam/__init__.py | 1 - src/transformers/models/segformer/__init__.py | 1 - .../segformer/image_processing_segformer.py | 5 +- .../models/seggpt/image_processing_seggpt.py | 5 +- .../models/siglip/image_processing_siglip.py | 5 +- .../siglip2/image_processing_siglip2.py | 5 +- .../smolvlm/image_processing_smolvlm.py | 5 +- .../models/speech_encoder_decoder/__init__.py | 1 - .../models/speech_to_text/__init__.py | 1 - .../superglue/image_processing_superglue.py | 5 +- .../superpoint/image_processing_superpoint.py | 5 +- .../models/swiftformer/__init__.py | 1 - src/transformers/models/swin/__init__.py | 1 - .../swin2sr/image_processing_swin2sr.py | 5 +- src/transformers/models/t5/__init__.py | 2 - .../t5/convert_t5x_checkpoint_to_flax.py | 235 ------------------ src/transformers/models/tapas/__init__.py | 1 - .../textnet/image_processing_textnet.py | 5 +- .../models/tvp/image_processing_tvp.py | 5 +- .../image_processing_video_llava.py | 5 +- .../videomae/image_processing_videomae.py | 5 +- .../models/vilt/image_processing_vilt.py | 5 +- .../models/vision_encoder_decoder/__init__.py | 2 - .../vision_text_dual_encoder/__init__.py | 2 - .../modeling_vision_text_dual_encoder.py | 12 +- src/transformers/models/vit/__init__.py | 2 - .../models/vit/image_processing_vit.py | 5 +- src/transformers/models/vit_mae/__init__.py | 1 - .../vitmatte/image_processing_vitmatte.py | 10 +- .../vitpose/image_processing_vitpose.py | 5 +- .../models/vivit/image_processing_vivit.py | 5 +- src/transformers/models/wav2vec2/__init__.py | 2 - src/transformers/models/whisper/__init__.py | 2 - src/transformers/models/xglm/__init__.py | 2 - .../models/xglm/configuration_xglm.py | 2 +- src/transformers/models/xlm/__init__.py | 1 - .../models/xlm_roberta/__init__.py | 2 - src/transformers/models/xlnet/__init__.py | 1 - .../zoedepth/image_processing_zoedepth.py | 5 +- src/transformers/pipelines/base.py | 3 +- src/transformers/utils/__init__.py | 5 - .../aya_vision/test_modeling_aya_vision.py | 2 +- tests/models/clip/test_modeling_clip.py | 2 +- .../test_modeling_cohere2_vision.py | 2 +- tests/models/colpali/test_modeling_colpali.py | 4 +- .../deepseek_vl/test_modeling_deepseek_vl.py | 2 +- .../test_modeling_deepseek_vl_hybrid.py | 2 +- tests/models/gemma3n/test_modeling_gemma3n.py | 4 +- .../metaclip_2/test_modeling_metaclip_2.py | 2 +- .../paligemma/test_modeling_paligemma.py | 4 +- .../paligemma2/test_modeling_paligemma2.py | 4 +- tests/models/siglip/test_modeling_siglip.py | 8 +- tests/models/siglip2/test_modeling_siglip2.py | 8 +- tests/utils/test_hub_utils.py | 6 +- utils/add_pipeline_model_mapping_to_test.py | 7 +- utils/check_docstrings.py | 8 +- utils/check_inits.py | 1 - utils/check_model_tester.py | 4 - utils/get_test_info.py | 9 +- utils/models_to_deprecate.py | 4 - utils/not_doctested.txt | 2 - utils/tests_fetcher.py | 6 +- utils/update_tiny_models.py | 16 +- 200 files changed, 134 insertions(+), 1011 deletions(-) delete mode 100644 src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py delete mode 100644 src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py diff --git a/examples/modular-transformers/image_processing_new_imgproc_model.py b/examples/modular-transformers/image_processing_new_imgproc_model.py index 7dae62f883f8..cd521a0f606d 100644 --- a/examples/modular-transformers/image_processing_new_imgproc_model.py +++ b/examples/modular-transformers/image_processing_new_imgproc_model.py @@ -223,10 +223,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/setup.py b/setup.py index f0f1598f4971..d6e69d3b83c5 100644 --- a/setup.py +++ b/setup.py @@ -109,7 +109,6 @@ "faiss-cpu", "fastapi", "filelock", - "flax>=0.4.1,<=0.7.0", "ftfy", "fugashi>=1.0", "GitPython<3.1.19", @@ -118,8 +117,6 @@ "huggingface-hub>=0.34.0,<1.0", "importlib_metadata", "ipadic>=1.0.0,<2.0", - "jax>=0.4.1,<=0.4.13", - "jaxlib>=0.4.1,<=0.4.13", "jinja2>=3.1.0", "kenlm", "kernels>=0.6.1,<=0.9", @@ -167,7 +164,7 @@ "sagemaker>=2.31.0", "schedulefree>=1.2.6", "scikit-learn", - "scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`) + "scipy", "sentencepiece>=0.1.91,!=0.1.92", "sigopt", "starlette", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 71fe0ae4ad48..6c44c89239a2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -208,8 +208,6 @@ "PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "SPIECE_UNDERLINE", - "TF2_WEIGHTS_NAME", - "TF_WEIGHTS_NAME", "TRANSFORMERS_CACHE", "WEIGHTS_NAME", "TensorType", @@ -748,8 +746,6 @@ from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE - from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME - from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE from .utils import WEIGHTS_NAME as WEIGHTS_NAME from .utils import TensorType as TensorType diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b9b70d256e8b..28a9f84b92a8 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -18,7 +18,6 @@ "faiss-cpu": "faiss-cpu", "fastapi": "fastapi", "filelock": "filelock", - "flax": "flax>=0.4.1,<=0.7.0", "ftfy": "ftfy", "fugashi": "fugashi>=1.0", "GitPython": "GitPython<3.1.19", @@ -27,8 +26,6 @@ "huggingface-hub": "huggingface-hub>=0.34.0,<1.0", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", - "jax": "jax>=0.4.1,<=0.4.13", - "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jinja2": "jinja2>=3.1.0", "kenlm": "kenlm", "kernels": "kernels>=0.6.1,<=0.9", @@ -73,7 +70,7 @@ "sagemaker": "sagemaker>=2.31.0", "schedulefree": "schedulefree>=1.2.6", "scikit-learn": "scikit-learn", - "scipy": "scipy<1.13.0", + "scipy": "scipy", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sigopt": "sigopt", "starlette": "starlette", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 65137cf0634f..91d7974b55c1 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -31,7 +31,6 @@ ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, FEATURE_EXTRACTOR_NAME, - FLAX_WEIGHTS_NAME, HF_MODULES_CACHE, HUGGINGFACE_CO_PREFIX, HUGGINGFACE_CO_RESOLVE_ENDPOINT, @@ -42,8 +41,6 @@ S3_BUCKET_PREFIX, SENTENCEPIECE_UNDERLINE, SPIECE_UNDERLINE, - TF2_WEIGHTS_NAME, - TF_WEIGHTS_NAME, TORCH_FX_REQUIRED_VERSION, TRANSFORMERS_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c4cc812f759a..8a0270567a29 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1163,8 +1163,7 @@ def _get_resolved_checkpoint_files( name="Thread-auto_conversion", ).start() else: - # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. - # We try those to give a helpful error message. + # Otherwise, no PyTorch file was found has_file_kwargs = { "revision": revision, "proxies": proxies, diff --git a/src/transformers/models/albert/__init__.py b/src/transformers/models/albert/__init__.py index 57b5747909e0..ac2cf362ebf2 100644 --- a/src/transformers/models/albert/__init__.py +++ b/src/transformers/models/albert/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_albert import * from .modeling_albert import * - from .modeling_flax_albert import * - from .modeling_tf_albert import * from .tokenization_albert import * from .tokenization_albert_fast import * else: diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 4fc2fcf7ec6b..f3f57b3d53c2 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -232,10 +232,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_normalize=do_normalize, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index a626d2cd4b82..405c3d21dadb 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -615,10 +615,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_normalize=do_normalize, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 34a6ae1e5c2e..6b86884b3b7b 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -23,8 +23,6 @@ from .feature_extraction_auto import * from .image_processing_auto import * from .modeling_auto import * - from .modeling_flax_auto import * - from .modeling_tf_auto import * from .processing_auto import * from .tokenization_auto import * from .video_processing_auto import * diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index 8f4c713f4698..d268fb7d2b86 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_bart import * from .modeling_bart import * - from .modeling_flax_bart import * - from .modeling_tf_bart import * from .tokenization_bart import * from .tokenization_bart_fast import * else: diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py index 3f412a350068..66dcfe1e56f7 100644 --- a/src/transformers/models/beit/__init__.py +++ b/src/transformers/models/beit/__init__.py @@ -23,7 +23,6 @@ from .image_processing_beit import * from .image_processing_beit_fast import * from .modeling_beit import * - from .modeling_flax_beit import * else: import sys diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index 0029480e46d6..984eac3bf67e 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -395,14 +395,10 @@ def preprocess( if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py index 2ef22794dde2..b78228a591aa 100644 --- a/src/transformers/models/bert/__init__.py +++ b/src/transformers/models/bert/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_bert import * from .modeling_bert import * - from .modeling_flax_bert import * - from .modeling_tf_bert import * from .tokenization_bert import * from .tokenization_bert_fast import * from .tokenization_bert_tf import * diff --git a/src/transformers/models/big_bird/__init__.py b/src/transformers/models/big_bird/__init__.py index 87419e69e5c7..e9bc0f08af3e 100644 --- a/src/transformers/models/big_bird/__init__.py +++ b/src/transformers/models/big_bird/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_big_bird import * from .modeling_big_bird import * - from .modeling_flax_big_bird import * from .tokenization_big_bird import * from .tokenization_big_bird_fast import * else: diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index 2c24f3f1b969..983149fea574 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -257,10 +257,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py index 76ece6853b38..6e728fd0914a 100644 --- a/src/transformers/models/blenderbot/__init__.py +++ b/src/transformers/models/blenderbot/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_blenderbot import * from .modeling_blenderbot import * - from .modeling_flax_blenderbot import * - from .modeling_tf_blenderbot import * from .tokenization_blenderbot import * from .tokenization_blenderbot_fast import * else: diff --git a/src/transformers/models/blenderbot_small/__init__.py b/src/transformers/models/blenderbot_small/__init__.py index 075d0070e4c4..7f08df82e757 100644 --- a/src/transformers/models/blenderbot_small/__init__.py +++ b/src/transformers/models/blenderbot_small/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_blenderbot_small import * from .modeling_blenderbot_small import * - from .modeling_flax_blenderbot_small import * - from .modeling_tf_blenderbot_small import * from .tokenization_blenderbot_small import * from .tokenization_blenderbot_small_fast import * else: diff --git a/src/transformers/models/blip/__init__.py b/src/transformers/models/blip/__init__.py index 952de2f855a7..c16593d7ce17 100644 --- a/src/transformers/models/blip/__init__.py +++ b/src/transformers/models/blip/__init__.py @@ -23,8 +23,6 @@ from .image_processing_blip_fast import * from .modeling_blip import * from .modeling_blip_text import * - from .modeling_tf_blip import * - from .modeling_tf_blip_text import * from .processing_blip import * else: import sys diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py index ca04b75583b0..0efc3c5d1eb3 100644 --- a/src/transformers/models/blip/image_processing_blip.py +++ b/src/transformers/models/blip/image_processing_blip.py @@ -233,10 +233,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py index 72d1d6e6ca47..4a938fd80b25 100644 --- a/src/transformers/models/bloom/__init__.py +++ b/src/transformers/models/bloom/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_bloom import * from .modeling_bloom import * - from .modeling_flax_bloom import * from .tokenization_bloom_fast import * else: import sys diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index b681292ae6d3..75b4e2b4238c 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -465,10 +465,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # Here, crop_size is used only if it is set, else size will be used. validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/camembert/__init__.py b/src/transformers/models/camembert/__init__.py index 9d90f64de97f..a3a9c395eb5b 100644 --- a/src/transformers/models/camembert/__init__.py +++ b/src/transformers/models/camembert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_camembert import * from .modeling_camembert import * - from .modeling_tf_camembert import * from .tokenization_camembert import * from .tokenization_camembert_fast import * else: diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index 4e9b2e1c1755..484ce53e729c 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -250,10 +250,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index 3bb47ffe97b3..1ada2c715669 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -251,10 +251,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py index 18a4db32e994..36fb3521a93e 100644 --- a/src/transformers/models/clip/__init__.py +++ b/src/transformers/models/clip/__init__.py @@ -23,8 +23,6 @@ from .image_processing_clip import * from .image_processing_clip_fast import * from .modeling_clip import * - from .modeling_flax_clip import * - from .modeling_tf_clip import * from .processing_clip import * from .tokenization_clip import * from .tokenization_clip_fast import * diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index 08b27680c8e9..ca5e00579f68 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -287,10 +287,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 887b400b4799..aae404d1e7f3 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -294,8 +294,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/convbert/__init__.py b/src/transformers/models/convbert/__init__.py index 670a7d6f4764..20999ba510da 100644 --- a/src/transformers/models/convbert/__init__.py +++ b/src/transformers/models/convbert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_convbert import * from .modeling_convbert import * - from .modeling_tf_convbert import * from .tokenization_convbert import * from .tokenization_convbert_fast import * else: diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py index e2d826745f5b..92f604100822 100644 --- a/src/transformers/models/convnext/__init__.py +++ b/src/transformers/models/convnext/__init__.py @@ -23,7 +23,6 @@ from .image_processing_convnext import * from .image_processing_convnext_fast import * from .modeling_convnext import * - from .modeling_tf_convnext import * else: import sys diff --git a/src/transformers/models/convnext/image_processing_convnext.py b/src/transformers/models/convnext/image_processing_convnext.py index 0ac9a8af06e0..ae0be69a5621 100644 --- a/src/transformers/models/convnext/image_processing_convnext.py +++ b/src/transformers/models/convnext/image_processing_convnext.py @@ -263,10 +263,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/convnextv2/__init__.py b/src/transformers/models/convnextv2/__init__.py index 0fd1293963b2..9e02170eceae 100644 --- a/src/transformers/models/convnextv2/__init__.py +++ b/src/transformers/models/convnextv2/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_convnextv2 import * from .modeling_convnextv2 import * - from .modeling_tf_convnextv2 import * else: import sys diff --git a/src/transformers/models/ctrl/__init__.py b/src/transformers/models/ctrl/__init__.py index ea62163babef..93f27ba0710e 100644 --- a/src/transformers/models/ctrl/__init__.py +++ b/src/transformers/models/ctrl/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_ctrl import * from .modeling_ctrl import * - from .modeling_tf_ctrl import * from .tokenization_ctrl import * else: import sys diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py index 756aded9e6ad..08a67f82b411 100644 --- a/src/transformers/models/cvt/__init__.py +++ b/src/transformers/models/cvt/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_cvt import * from .modeling_cvt import * - from .modeling_tf_cvt import * else: import sys diff --git a/src/transformers/models/data2vec/__init__.py b/src/transformers/models/data2vec/__init__.py index 7000ac3d353b..4fcf78dd606f 100644 --- a/src/transformers/models/data2vec/__init__.py +++ b/src/transformers/models/data2vec/__init__.py @@ -24,7 +24,6 @@ from .modeling_data2vec_audio import * from .modeling_data2vec_text import * from .modeling_data2vec_vision import * - from .modeling_tf_data2vec_vision import * else: import sys diff --git a/src/transformers/models/deberta/__init__.py b/src/transformers/models/deberta/__init__.py index f70972237964..ac2dbc3af259 100644 --- a/src/transformers/models/deberta/__init__.py +++ b/src/transformers/models/deberta/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_deberta import * from .modeling_deberta import * - from .modeling_tf_deberta import * from .tokenization_deberta import * from .tokenization_deberta_fast import * else: diff --git a/src/transformers/models/deberta_v2/__init__.py b/src/transformers/models/deberta_v2/__init__.py index 7c42c9c50286..929b26e60ae0 100644 --- a/src/transformers/models/deberta_v2/__init__.py +++ b/src/transformers/models/deberta_v2/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_deberta_v2 import * from .modeling_deberta_v2 import * - from .modeling_tf_deberta_v2 import * from .tokenization_deberta_v2 import * from .tokenization_deberta_v2_fast import * else: diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py index dca5eeb296ee..9d3d9a408a00 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py @@ -281,10 +281,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py index 0ac9602eb72c..d3d5a7e3e542 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py @@ -331,10 +331,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 3cc62d7449e5..e9808b02ce34 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -643,10 +643,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/deit/__init__.py b/src/transformers/models/deit/__init__.py index 98236a86d7a1..ef3e5149fe55 100644 --- a/src/transformers/models/deit/__init__.py +++ b/src/transformers/models/deit/__init__.py @@ -23,7 +23,6 @@ from .image_processing_deit import * from .image_processing_deit_fast import * from .modeling_deit import * - from .modeling_tf_deit import * else: import sys diff --git a/src/transformers/models/deit/image_processing_deit.py b/src/transformers/models/deit/image_processing_deit.py index dbb83412c563..795c872c62e5 100644 --- a/src/transformers/models/deit/image_processing_deit.py +++ b/src/transformers/models/deit/image_processing_deit.py @@ -241,10 +241,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/deprecated/efficientformer/__init__.py b/src/transformers/models/deprecated/efficientformer/__init__.py index db3d0a634051..e43cb0e61df1 100644 --- a/src/transformers/models/deprecated/efficientformer/__init__.py +++ b/src/transformers/models/deprecated/efficientformer/__init__.py @@ -21,7 +21,6 @@ from .configuration_efficientformer import * from .image_processing_efficientformer import * from .modeling_efficientformer import * - from .modeling_tf_efficientformer import * else: import sys diff --git a/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py b/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py index ec9e4f6f3695..a8dcedea620a 100644 --- a/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py +++ b/src/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py @@ -259,10 +259,7 @@ def preprocess( images = [images] if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/deprecated/mega/modeling_mega.py b/src/transformers/models/deprecated/mega/modeling_mega.py index c237afee9a33..02e41e91b24c 100644 --- a/src/transformers/models/deprecated/mega/modeling_mega.py +++ b/src/transformers/models/deprecated/mega/modeling_mega.py @@ -154,8 +154,6 @@ def __init__(self, config: MegaConfig): self.sine, self.cosine = MegaRotaryRelativePositionalBias.get_sinusoid_embeddings( config.max_positions, self.embed_dim ) - # alpha and beta parameters for the rotary bias; beta renamed to b_param to avoid clashes with tf/flax weight handling - # in loading pretrained weights self.alpha = nn.Parameter(torch.Tensor(1, self.embed_dim)) self.b_param = nn.Parameter(torch.Tensor(1, self.embed_dim)) self.register_buffer("_float_tensor", torch.FloatTensor([0.0])) diff --git a/src/transformers/models/deprecated/transfo_xl/__init__.py b/src/transformers/models/deprecated/transfo_xl/__init__.py index 0ac9a2cbf476..9bd3dd7b8838 100644 --- a/src/transformers/models/deprecated/transfo_xl/__init__.py +++ b/src/transformers/models/deprecated/transfo_xl/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_transfo_xl import * - from .modeling_tf_transfo_xl import * from .modeling_transfo_xl import * from .tokenization_transfo_xl import * else: diff --git a/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py b/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py index 224a35eb0e79..19b5cddb246b 100644 --- a/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/image_processing_tvlt.py @@ -380,8 +380,7 @@ def preprocess( if not valid_images(videos): raise ValueError( - "Invalid image or video type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid image or video type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) videos = make_batched(videos) diff --git a/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py index 27b1f83a7ba8..662382be43df 100644 --- a/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py @@ -276,10 +276,7 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro.py b/src/transformers/models/depth_pro/image_processing_depth_pro.py index 2b81c95db818..5b76d8cbc8e6 100644 --- a/src/transformers/models/depth_pro/image_processing_depth_pro.py +++ b/src/transformers/models/depth_pro/image_processing_depth_pro.py @@ -259,10 +259,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") self._validate_input_arguments( do_resize=do_resize, size=size, diff --git a/src/transformers/models/dinov2/__init__.py b/src/transformers/models/dinov2/__init__.py index 3cc316957eac..002634ed4b49 100644 --- a/src/transformers/models/dinov2/__init__.py +++ b/src/transformers/models/dinov2/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_dinov2 import * from .modeling_dinov2 import * - from .modeling_flax_dinov2 import * else: import sys diff --git a/src/transformers/models/distilbert/__init__.py b/src/transformers/models/distilbert/__init__.py index 4d6fae2e0236..094524ab267f 100644 --- a/src/transformers/models/distilbert/__init__.py +++ b/src/transformers/models/distilbert/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_distilbert import * from .modeling_distilbert import * - from .modeling_flax_distilbert import * - from .modeling_tf_distilbert import * from .tokenization_distilbert import * from .tokenization_distilbert_fast import * else: diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 75bcb9cda994..f49cc964080d 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -396,10 +396,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/dpr/__init__.py b/src/transformers/models/dpr/__init__.py index 9aeadbeaf416..386b47bb2ecd 100644 --- a/src/transformers/models/dpr/__init__.py +++ b/src/transformers/models/dpr/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_dpr import * from .modeling_dpr import * - from .modeling_tf_dpr import * from .tokenization_dpr import * from .tokenization_dpr_fast import * else: diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index 70440017e518..0ec3eaed1c43 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -529,10 +529,7 @@ def preprocess( segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py index 45f28220ae84..5b87278683ac 100644 --- a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py +++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py @@ -287,10 +287,7 @@ def preprocess( images = validate_and_format_image_pairs(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_resize=do_resize, diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet.py b/src/transformers/models/efficientnet/image_processing_efficientnet.py index 5331f4a4dea0..eaad420b31f8 100644 --- a/src/transformers/models/efficientnet/image_processing_efficientnet.py +++ b/src/transformers/models/efficientnet/image_processing_efficientnet.py @@ -296,10 +296,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index a78ed5c42aea..506212b561e1 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_electra import * from .modeling_electra import * - from .modeling_flax_electra import * - from .modeling_tf_electra import * from .tokenization_electra import * from .tokenization_electra_fast import * else: diff --git a/src/transformers/models/emu3/image_processing_emu3.py b/src/transformers/models/emu3/image_processing_emu3.py index be4decd410dc..c46dce41f529 100644 --- a/src/transformers/models/emu3/image_processing_emu3.py +++ b/src/transformers/models/emu3/image_processing_emu3.py @@ -360,10 +360,7 @@ def preprocess( images = make_nested_list_of_images(images) if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( rescale_factor=rescale_factor, diff --git a/src/transformers/models/encoder_decoder/__init__.py b/src/transformers/models/encoder_decoder/__init__.py index c786feb9213f..b1cde1442a13 100644 --- a/src/transformers/models/encoder_decoder/__init__.py +++ b/src/transformers/models/encoder_decoder/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_encoder_decoder import * from .modeling_encoder_decoder import * - from .modeling_flax_encoder_decoder import * - from .modeling_tf_encoder_decoder import * else: import sys diff --git a/src/transformers/models/eomt/image_processing_eomt.py b/src/transformers/models/eomt/image_processing_eomt.py index 83bc70521019..4fbd308da336 100644 --- a/src/transformers/models/eomt/image_processing_eomt.py +++ b/src/transformers/models/eomt/image_processing_eomt.py @@ -581,10 +581,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/esm/__init__.py b/src/transformers/models/esm/__init__.py index 8eac54d6ddcb..e308c53e9a3d 100644 --- a/src/transformers/models/esm/__init__.py +++ b/src/transformers/models/esm/__init__.py @@ -21,7 +21,6 @@ from .configuration_esm import * from .modeling_esm import * from .modeling_esmfold import * - from .modeling_tf_esm import * from .tokenization_esm import * else: import sys diff --git a/src/transformers/models/esm/openfold_utils/tensor_utils.py b/src/transformers/models/esm/openfold_utils/tensor_utils.py index 449c810aed3f..c776f89659b6 100644 --- a/src/transformers/models/esm/openfold_utils/tensor_utils.py +++ b/src/transformers/models/esm/openfold_utils/tensor_utils.py @@ -93,7 +93,6 @@ def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batc T = TypeVar("T") -# With tree_map, a poor man's JAX tree_map def dict_map( fn: Callable[[T], Any], dic: dict[Any, Union[dict, list, tuple, T]], leaf_type: type[T] ) -> dict[Any, Union[dict, list, tuple, Any]]: diff --git a/src/transformers/models/flaubert/__init__.py b/src/transformers/models/flaubert/__init__.py index e981d9cbcb1e..e418a0f74381 100644 --- a/src/transformers/models/flaubert/__init__.py +++ b/src/transformers/models/flaubert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_flaubert import * from .modeling_flaubert import * - from .modeling_tf_flaubert import * from .tokenization_flaubert import * else: import sys diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py index 1e3fecfd9740..9d67ac841124 100644 --- a/src/transformers/models/flava/image_processing_flava.py +++ b/src/transformers/models/flava/image_processing_flava.py @@ -638,10 +638,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") processed_images = [ self._preprocess_image( diff --git a/src/transformers/models/funnel/__init__.py b/src/transformers/models/funnel/__init__.py index e4e0587ce32f..1a75ee7e2e61 100644 --- a/src/transformers/models/funnel/__init__.py +++ b/src/transformers/models/funnel/__init__.py @@ -21,7 +21,6 @@ from .configuration_funnel import * from .convert_funnel_original_tf_checkpoint_to_pytorch import * from .modeling_funnel import * - from .modeling_tf_funnel import * from .tokenization_funnel import * from .tokenization_funnel_fast import * else: diff --git a/src/transformers/models/gemma/__init__.py b/src/transformers/models/gemma/__init__.py index 65fb1ca5edef..80c8d30760c4 100644 --- a/src/transformers/models/gemma/__init__.py +++ b/src/transformers/models/gemma/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_gemma import * - from .modeling_flax_gemma import * from .modeling_gemma import * from .tokenization_gemma import * from .tokenization_gemma_fast import * diff --git a/src/transformers/models/gemma3/image_processing_gemma3.py b/src/transformers/models/gemma3/image_processing_gemma3.py index 02db120c8414..efa65a6d2bf2 100644 --- a/src/transformers/models/gemma3/image_processing_gemma3.py +++ b/src/transformers/models/gemma3/image_processing_gemma3.py @@ -336,10 +336,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py index 62e3fb3878f7..7dcc4e2c5ca8 100644 --- a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py +++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py @@ -296,7 +296,7 @@ def __call__( pad_to_multiple_of (`int`, *optional*, defaults to 128): When padding, pad to a multiple of this value. The default value is defined for optimal TPU support. return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`): - The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow). + The type of tensors to return (e.g., NumPy, or Torch). return_attention_mask (`bool`, *optional*, defaults to `True`): Whether to return the attention mask for the generated MEL spectrograms. """ diff --git a/src/transformers/models/glm4v/image_processing_glm4v.py b/src/transformers/models/glm4v/image_processing_glm4v.py index ad6549826fe2..e35699005116 100644 --- a/src/transformers/models/glm4v/image_processing_glm4v.py +++ b/src/transformers/models/glm4v/image_processing_glm4v.py @@ -391,10 +391,7 @@ def preprocess( images = make_flat_list_of_images(images) if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( rescale_factor=rescale_factor, diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index 2e2d82a23322..35306eabc8d5 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -187,10 +187,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # Here, the rescale() method uses a constant rescale_factor. It does not need to be validated # with a rescale_factor. diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py index 6880cfe208a0..43bf8b520ffa 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -342,10 +342,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py index f01899e668e3..58c4f4e012f5 100644 --- a/src/transformers/models/gpt2/__init__.py +++ b/src/transformers/models/gpt2/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_gpt2 import * - from .modeling_flax_gpt2 import * from .modeling_gpt2 import * - from .modeling_tf_gpt2 import * from .tokenization_gpt2 import * from .tokenization_gpt2_fast import * from .tokenization_gpt2_tf import * diff --git a/src/transformers/models/gpt_neo/__init__.py b/src/transformers/models/gpt_neo/__init__.py index 578577f22882..242a20d00d6d 100644 --- a/src/transformers/models/gpt_neo/__init__.py +++ b/src/transformers/models/gpt_neo/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_gpt_neo import * - from .modeling_flax_gpt_neo import * from .modeling_gpt_neo import * else: import sys diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py index 84d99fda2e69..a814910a8885 100644 --- a/src/transformers/models/gptj/__init__.py +++ b/src/transformers/models/gptj/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_gptj import * - from .modeling_flax_gptj import * from .modeling_gptj import * - from .modeling_tf_gptj import * else: import sys diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index cb6a4f579c52..cf63907dc6bf 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -480,8 +480,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/groupvit/__init__.py b/src/transformers/models/groupvit/__init__.py index ab7fa27d09d1..10c315e28015 100644 --- a/src/transformers/models/groupvit/__init__.py +++ b/src/transformers/models/groupvit/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_groupvit import * from .modeling_groupvit import * - from .modeling_tf_groupvit import * else: import sys diff --git a/src/transformers/models/hubert/__init__.py b/src/transformers/models/hubert/__init__.py index d975dabc689a..25d366620f0b 100644 --- a/src/transformers/models/hubert/__init__.py +++ b/src/transformers/models/hubert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_hubert import * from .modeling_hubert import * - from .modeling_tf_hubert import * else: import sys diff --git a/src/transformers/models/idefics/__init__.py b/src/transformers/models/idefics/__init__.py index 4adb66825445..fedac647008b 100644 --- a/src/transformers/models/idefics/__init__.py +++ b/src/transformers/models/idefics/__init__.py @@ -21,7 +21,6 @@ from .configuration_idefics import * from .image_processing_idefics import * from .modeling_idefics import * - from .modeling_tf_idefics import * from .processing_idefics import * else: import sys diff --git a/src/transformers/models/idefics/image_processing_idefics.py b/src/transformers/models/idefics/image_processing_idefics.py index fe9085331cde..6ef5b39afeeb 100644 --- a/src/transformers/models/idefics/image_processing_idefics.py +++ b/src/transformers/models/idefics/image_processing_idefics.py @@ -155,10 +155,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # For training a user needs to pass their own set of transforms as a Callable. # For reference this is what was used in the original IDEFICS training: diff --git a/src/transformers/models/idefics2/image_processing_idefics2.py b/src/transformers/models/idefics2/image_processing_idefics2.py index 22befb6ceaed..15a04a887e87 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2.py +++ b/src/transformers/models/idefics2/image_processing_idefics2.py @@ -472,10 +472,7 @@ def preprocess( images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index e298f0890c99..c7526f30993a 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -689,10 +689,7 @@ def preprocess( images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py index 42f3a296103b..97086ed45e07 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py @@ -239,10 +239,7 @@ def preprocess( images = make_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # Here, normalize() is using a constant factor to divide pixel values. # hence, the method does not need iamge_mean and image_std. diff --git a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py index 3e1ee362aad8..ccd0d701738c 100644 --- a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py @@ -251,10 +251,7 @@ def preprocess( ) if not valid_images(videos): - raise ValueError( - "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") pixel_values = [ [ diff --git a/src/transformers/models/janus/image_processing_janus.py b/src/transformers/models/janus/image_processing_janus.py index b995cdcf8b92..c1f34efde71e 100644 --- a/src/transformers/models/janus/image_processing_janus.py +++ b/src/transformers/models/janus/image_processing_janus.py @@ -278,10 +278,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 0ca0a9a43669..06388d96f1e7 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -843,8 +843,6 @@ class JetMoePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py index a16a1eba626a..b768205da2a4 100644 --- a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py @@ -292,10 +292,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # PIL RGBA images are converted to RGB if do_convert_rgb: diff --git a/src/transformers/models/layoutlm/__init__.py b/src/transformers/models/layoutlm/__init__.py index 0f079c33c715..5db595015b49 100644 --- a/src/transformers/models/layoutlm/__init__.py +++ b/src/transformers/models/layoutlm/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_layoutlm import * from .modeling_layoutlm import * - from .modeling_tf_layoutlm import * from .tokenization_layoutlm import * from .tokenization_layoutlm_fast import * else: diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py index c959583ae5bc..76fc752bbeea 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -252,10 +252,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_resize=do_resize, size=size, diff --git a/src/transformers/models/layoutlmv3/__init__.py b/src/transformers/models/layoutlmv3/__init__.py index c776b676f603..c87afd9c58e7 100644 --- a/src/transformers/models/layoutlmv3/__init__.py +++ b/src/transformers/models/layoutlmv3/__init__.py @@ -23,7 +23,6 @@ from .image_processing_layoutlmv3 import * from .image_processing_layoutlmv3_fast import * from .modeling_layoutlmv3 import * - from .modeling_tf_layoutlmv3 import * from .processing_layoutlmv3 import * from .tokenization_layoutlmv3 import * from .tokenization_layoutlmv3_fast import * diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py index 855ae52e8075..0ce7f5ce6968 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py @@ -307,10 +307,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/led/__init__.py b/src/transformers/models/led/__init__.py index 786ebd36d7b8..678b3af02aaf 100644 --- a/src/transformers/models/led/__init__.py +++ b/src/transformers/models/led/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_led import * from .modeling_led import * - from .modeling_tf_led import * from .tokenization_led import * from .tokenization_led_fast import * else: diff --git a/src/transformers/models/levit/image_processing_levit.py b/src/transformers/models/levit/image_processing_levit.py index 09c23d3bad91..021c6f4aa652 100644 --- a/src/transformers/models/levit/image_processing_levit.py +++ b/src/transformers/models/levit/image_processing_levit.py @@ -256,10 +256,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py index 4263c3712407..855fbb12d641 100644 --- a/src/transformers/models/lightglue/image_processing_lightglue.py +++ b/src/transformers/models/lightglue/image_processing_lightglue.py @@ -288,10 +288,7 @@ def preprocess( images = validate_and_format_image_pairs(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_resize=do_resize, diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 0677bb91435c..3166111744a1 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_llama import * - from .modeling_flax_llama import * from .modeling_llama import * from .tokenization_llama import * from .tokenization_llama_fast import * diff --git a/src/transformers/models/llava/image_processing_llava.py b/src/transformers/models/llava/image_processing_llava.py index a77e7649b7d0..543b22dc431f 100644 --- a/src/transformers/models/llava/image_processing_llava.py +++ b/src/transformers/models/llava/image_processing_llava.py @@ -369,10 +369,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # we don't pass `do_pad` here since LLaVa uses a custom padding to a square validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 7b8bc4a513ef..07d8a934db21 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -641,10 +641,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index 65098512366b..d7593a5355bd 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -694,10 +694,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/longformer/__init__.py b/src/transformers/models/longformer/__init__.py index 87f53105424b..a0ee62087e76 100644 --- a/src/transformers/models/longformer/__init__.py +++ b/src/transformers/models/longformer/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_longformer import * from .modeling_longformer import * - from .modeling_tf_longformer import * from .tokenization_longformer import * from .tokenization_longformer_fast import * else: diff --git a/src/transformers/models/longt5/__init__.py b/src/transformers/models/longt5/__init__.py index 2716e62cd7b2..9821ef87bc36 100644 --- a/src/transformers/models/longt5/__init__.py +++ b/src/transformers/models/longt5/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_longt5 import * - from .modeling_flax_longt5 import * from .modeling_longt5 import * else: import sys diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py index 245e9948a1ae..b4833f4394e7 100644 --- a/src/transformers/models/longt5/configuration_longt5.py +++ b/src/transformers/models/longt5/configuration_longt5.py @@ -26,7 +26,7 @@ class LongT5Config(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is + This is the configuration class to store the configuration of a [`LongT5Model`]. It is used to instantiate a LongT5 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5 [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture. diff --git a/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py deleted file mode 100644 index d99797107363..000000000000 --- a/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py +++ /dev/null @@ -1,215 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of -'src/transformers/models/t5/convert_t5x_checkpoint_to_flax. -""" - -import argparse - -from t5x import checkpoints - -from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM - - -def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path): - config = AutoConfig.from_pretrained(config_name) - flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config) - t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) - - split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"] - - if config.model_type == "t5": - encoder_attn_name = "SelfAttention" - if config.model_type == "longt5" and config.encoder_attention_type == "local": - encoder_attn_name = "LocalSelfAttention" - elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global": - encoder_attn_name = "TransientGlobalSelfAttention" - else: - raise ValueError( - "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`" - " attribute with a value from ['local', 'transient-global]." - ) - - # Encoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] - t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] - t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] - t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] - - # Global input layer norm - if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": - t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"] - - # Layer Normalization - t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] - - if split_mlp_wi: - t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] - t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] - else: - t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] - - t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] - - # Layer Normalization - t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] - - # Assigning - flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"] - flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key - flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out - flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query - flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value - - flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm - - # Global input layer norm - if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": - flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"]["weight"] = ( - t5x_global_layer_norm - ) - - if split_mlp_wi: - flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0 - flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1 - else: - flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi - - flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo - flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block - - # Only for layer 0: - t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][ - "embedding" - ] = t5x_encoder_rel_embedding - - # Side/global relative position_bias + layer norm - if config.model_type == "longt5" and config.encoder_attention_type == "transient-global": - t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][ - "embedding" - ] = t5x_encoder_global_rel_embedding - - # Assigning - t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm - - # Decoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] - t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] - t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] - t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] - - # Layer Normalization - t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ - "scale" - ] - - # Encoder-Decoder-Attention - t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"] - t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"] - t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"] - t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"] - t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"] - - # Layer Normalization - t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] - - # MLP - if split_mlp_wi: - t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] - t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] - else: - t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] - - t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] - - # Layer Normalization - tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] - - # Assigning - flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"] - flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key - flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out - flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query - flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value - - flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm - - flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key - flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out - flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query - flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value - - flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm - - if split_mlp_wi: - flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0 - flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1 - else: - flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi - - flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo - - flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block - - # Decoder Normalization - tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"] - flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm - - # Only for layer 0: - t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = t5x_decoder_rel_embedding - - # Token Embeddings - tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"] - flax_model.params["shared"]["embedding"] = tx5_token_embeddings - - # LM Head (only in v1.1 and LongT5 checkpoints) - if "logits_dense" in t5x_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"] - - flax_model.save_pretrained(flax_dump_folder_path) - print("T5X Model was successfully converted!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint." - ) - parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.") - parser.add_argument( - "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." - ) - args = parser.parse_args() - convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/lxmert/__init__.py b/src/transformers/models/lxmert/__init__.py index 3ad507465039..8cde45820316 100644 --- a/src/transformers/models/lxmert/__init__.py +++ b/src/transformers/models/lxmert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_lxmert import * from .modeling_lxmert import * - from .modeling_tf_lxmert import * from .tokenization_lxmert import * from .tokenization_lxmert_fast import * else: diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py index 6cfabc1590f2..84afe5b372bb 100644 --- a/src/transformers/models/marian/__init__.py +++ b/src/transformers/models/marian/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_marian import * - from .modeling_flax_marian import * from .modeling_marian import * - from .modeling_tf_marian import * from .tokenization_marian import * else: import sys diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py index bebab8b9e2da..14f75a8c414f 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former.py +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -739,10 +739,7 @@ def preprocess( pad_size = self.pad_size if pad_size is None else pad_size if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, @@ -757,8 +754,7 @@ def preprocess( if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) images = make_flat_list_of_images(images) diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index f94202f47243..f537adad22bd 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -742,10 +742,7 @@ def preprocess( pad_size = self.pad_size if pad_size is None else pad_size if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, @@ -760,8 +757,7 @@ def preprocess( if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) images = make_flat_list_of_images(images) diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index f3c5ef5767a4..0b90185d2cbe 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_mbart import * - from .modeling_flax_mbart import * from .modeling_mbart import * - from .modeling_tf_mbart import * from .tokenization_mbart import * from .tokenization_mbart_fast import * else: diff --git a/src/transformers/models/mistral/__init__.py b/src/transformers/models/mistral/__init__.py index 18a5657cd2ec..ea17b3d67bc4 100644 --- a/src/transformers/models/mistral/__init__.py +++ b/src/transformers/models/mistral/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_mistral import * - from .modeling_flax_mistral import * from .modeling_mistral import * - from .modeling_tf_mistral import * else: import sys diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py index 4ea599122ddc..0066f7f2b382 100644 --- a/src/transformers/models/mobilebert/__init__.py +++ b/src/transformers/models/mobilebert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_mobilebert import * from .modeling_mobilebert import * - from .modeling_tf_mobilebert import * from .tokenization_mobilebert import * from .tokenization_mobilebert_fast import * else: diff --git a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py index 897a5f2074d9..da384d40b3ed 100644 --- a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -246,10 +246,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index 3224c2665704..186dc3cf5772 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -411,15 +411,11 @@ def preprocess( segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) validate_preprocess_arguments( diff --git a/src/transformers/models/mobilevit/__init__.py b/src/transformers/models/mobilevit/__init__.py index 6750449a3eae..282e858e6798 100644 --- a/src/transformers/models/mobilevit/__init__.py +++ b/src/transformers/models/mobilevit/__init__.py @@ -23,7 +23,6 @@ from .image_processing_mobilevit import * from .image_processing_mobilevit_fast import * from .modeling_mobilevit import * - from .modeling_tf_mobilevit import * else: import sys diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index 93ff9bc6a1c2..0ea7a0706cc4 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -410,15 +410,11 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) validate_preprocess_arguments( diff --git a/src/transformers/models/mpnet/__init__.py b/src/transformers/models/mpnet/__init__.py index 0b7abc8357cc..402cc164b979 100644 --- a/src/transformers/models/mpnet/__init__.py +++ b/src/transformers/models/mpnet/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_mpnet import * from .modeling_mpnet import * - from .modeling_tf_mpnet import * from .tokenization_mpnet import * from .tokenization_mpnet_fast import * else: diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index 444a8f8cc8e0..f04d056d6e08 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_mt5 import * - from .modeling_flax_mt5 import * from .modeling_mt5 import * - from .modeling_tf_mt5 import * from .tokenization_mt5 import * else: import sys diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py index 793f450484fb..9cb26feafa10 100644 --- a/src/transformers/models/nougat/image_processing_nougat.py +++ b/src/transformers/models/nougat/image_processing_nougat.py @@ -449,10 +449,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 7311e017e213..abd178926d71 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -700,10 +700,7 @@ def preprocess( do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, @@ -718,8 +715,7 @@ def preprocess( if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor" ) images = make_flat_list_of_images(images) diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai/__init__.py index a07b0ab669f3..98a22135ea40 100644 --- a/src/transformers/models/openai/__init__.py +++ b/src/transformers/models/openai/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_openai import * from .modeling_openai import * - from .modeling_tf_openai import * from .tokenization_openai import * from .tokenization_openai_fast import * else: diff --git a/src/transformers/models/opt/__init__.py b/src/transformers/models/opt/__init__.py index d230de5ecadc..ecf8f8dee945 100644 --- a/src/transformers/models/opt/__init__.py +++ b/src/transformers/models/opt/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_opt import * - from .modeling_flax_opt import * from .modeling_opt import * - from .modeling_tf_opt import * else: import sys diff --git a/src/transformers/models/ovis2/image_processing_ovis2.py b/src/transformers/models/ovis2/image_processing_ovis2.py index ce776d24c7db..c235504d2d89 100644 --- a/src/transformers/models/ovis2/image_processing_ovis2.py +++ b/src/transformers/models/ovis2/image_processing_ovis2.py @@ -402,10 +402,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py index 19a04cdeb871..a79cc57a6c94 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2.py +++ b/src/transformers/models/owlv2/image_processing_owlv2.py @@ -435,10 +435,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # Here, pad and resize methods are different from the rest of image processors # as they don't have any resampling in resize() # or pad size in pad() (the maximum of (height, width) is taken instead). diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index 0af755297afd..42e3f10269b4 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -383,10 +383,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index 4903c400f982..4070d841ea3d 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_pegasus import * - from .modeling_flax_pegasus import * from .modeling_pegasus import * - from .modeling_tf_pegasus import * from .tokenization_pegasus import * from .tokenization_pegasus_fast import * else: diff --git a/src/transformers/models/perceiver/image_processing_perceiver.py b/src/transformers/models/perceiver/image_processing_perceiver.py index acc72f0ab877..376d33f8c356 100644 --- a/src/transformers/models/perceiver/image_processing_perceiver.py +++ b/src/transformers/models/perceiver/image_processing_perceiver.py @@ -287,10 +287,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/pix2struct/image_processing_pix2struct.py b/src/transformers/models/pix2struct/image_processing_pix2struct.py index 610ac36ea086..316f2021461c 100644 --- a/src/transformers/models/pix2struct/image_processing_pix2struct.py +++ b/src/transformers/models/pix2struct/image_processing_pix2struct.py @@ -402,10 +402,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # PIL RGBA images are converted to RGB if do_convert_rgb: diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 4c08f533e9e2..86b11cd1f61a 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -399,10 +399,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images[0]): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/poolformer/image_processing_poolformer.py b/src/transformers/models/poolformer/image_processing_poolformer.py index b8ea103a5366..ce3cd398745c 100644 --- a/src/transformers/models/poolformer/image_processing_poolformer.py +++ b/src/transformers/models/poolformer/image_processing_poolformer.py @@ -292,10 +292,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index 985fe290f689..7224aeef8612 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -366,10 +366,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/pvt/image_processing_pvt.py b/src/transformers/models/pvt/image_processing_pvt.py index c86727fff290..faec1739c811 100644 --- a/src/transformers/models/pvt/image_processing_pvt.py +++ b/src/transformers/models/pvt/image_processing_pvt.py @@ -217,10 +217,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py index 70af072d1bee..36a58d68730b 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py @@ -406,10 +406,7 @@ def preprocess( images = make_flat_list_of_images(images) if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( rescale_factor=rescale_factor, diff --git a/src/transformers/models/rag/__init__.py b/src/transformers/models/rag/__init__.py index 8a8f135ba454..ce12d1526149 100644 --- a/src/transformers/models/rag/__init__.py +++ b/src/transformers/models/rag/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_rag import * from .modeling_rag import * - from .modeling_tf_rag import * from .retrieval_rag import * from .tokenization_rag import * else: diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py index cac770fdd0bc..a49c5ee7f2e4 100644 --- a/src/transformers/models/regnet/__init__.py +++ b/src/transformers/models/regnet/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_regnet import * - from .modeling_flax_regnet import * from .modeling_regnet import * - from .modeling_tf_regnet import * else: import sys diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py index 38566f502ad0..23b308c7f13d 100644 --- a/src/transformers/models/rembert/__init__.py +++ b/src/transformers/models/rembert/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_rembert import * from .modeling_rembert import * - from .modeling_tf_rembert import * from .tokenization_rembert import * from .tokenization_rembert_fast import * else: diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py index 625e93a25543..db16908bff31 100644 --- a/src/transformers/models/resnet/__init__.py +++ b/src/transformers/models/resnet/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_resnet import * - from .modeling_flax_resnet import * from .modeling_resnet import * - from .modeling_tf_resnet import * else: import sys diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py index 9f9418d33d35..a82d4c9bc617 100644 --- a/src/transformers/models/roberta/__init__.py +++ b/src/transformers/models/roberta/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_roberta import * - from .modeling_flax_roberta import * from .modeling_roberta import * - from .modeling_tf_roberta import * from .tokenization_roberta import * from .tokenization_roberta_fast import * else: diff --git a/src/transformers/models/roberta_prelayernorm/__init__.py b/src/transformers/models/roberta_prelayernorm/__init__.py index 208878343d24..369698d84ba0 100644 --- a/src/transformers/models/roberta_prelayernorm/__init__.py +++ b/src/transformers/models/roberta_prelayernorm/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_roberta_prelayernorm import * - from .modeling_flax_roberta_prelayernorm import * from .modeling_roberta_prelayernorm import * - from .modeling_tf_roberta_prelayernorm import * else: import sys diff --git a/src/transformers/models/roformer/__init__.py b/src/transformers/models/roformer/__init__.py index 63c1c00e5723..4d1232523f8c 100644 --- a/src/transformers/models/roformer/__init__.py +++ b/src/transformers/models/roformer/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_roformer import * - from .modeling_flax_roformer import * from .modeling_roformer import * - from .modeling_tf_roformer import * from .tokenization_roformer import * from .tokenization_roformer_fast import * else: diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr.py b/src/transformers/models/rt_detr/image_processing_rt_detr.py index 4603a21095e7..cf657867a9f8 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -893,10 +893,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # All transformations expect numpy arrays images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/sam/__init__.py b/src/transformers/models/sam/__init__.py index bb8a2b98e636..ac0a21f82930 100644 --- a/src/transformers/models/sam/__init__.py +++ b/src/transformers/models/sam/__init__.py @@ -22,7 +22,6 @@ from .image_processing_sam import * from .image_processing_sam_fast import * from .modeling_sam import * - from .modeling_tf_sam import * from .processing_sam import * else: import sys diff --git a/src/transformers/models/segformer/__init__.py b/src/transformers/models/segformer/__init__.py index 81655dfa7048..c9b88d1a98c6 100644 --- a/src/transformers/models/segformer/__init__.py +++ b/src/transformers/models/segformer/__init__.py @@ -23,7 +23,6 @@ from .image_processing_segformer import * from .image_processing_segformer_fast import * from .modeling_segformer import * - from .modeling_tf_segformer import * else: import sys diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index 15c1b899e7dd..0894c352de8b 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ b/src/transformers/models/segformer/image_processing_segformer.py @@ -372,10 +372,7 @@ def preprocess( segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/seggpt/image_processing_seggpt.py b/src/transformers/models/seggpt/image_processing_seggpt.py index 8ced740d3825..f78536b18867 100644 --- a/src/transformers/models/seggpt/image_processing_seggpt.py +++ b/src/transformers/models/seggpt/image_processing_seggpt.py @@ -324,10 +324,7 @@ def _preprocess_step( images = make_flat_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if do_resize and size is None: raise ValueError("Size must be specified if do_resize is True.") diff --git a/src/transformers/models/siglip/image_processing_siglip.py b/src/transformers/models/siglip/image_processing_siglip.py index 60ee68e1eba2..39ecb0e15b11 100644 --- a/src/transformers/models/siglip/image_processing_siglip.py +++ b/src/transformers/models/siglip/image_processing_siglip.py @@ -183,10 +183,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/siglip2/image_processing_siglip2.py b/src/transformers/models/siglip2/image_processing_siglip2.py index ef8685b1c312..74a166c440b5 100644 --- a/src/transformers/models/siglip2/image_processing_siglip2.py +++ b/src/transformers/models/siglip2/image_processing_siglip2.py @@ -269,10 +269,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/smolvlm/image_processing_smolvlm.py b/src/transformers/models/smolvlm/image_processing_smolvlm.py index 36f42462540b..8a8ee5d4aa14 100644 --- a/src/transformers/models/smolvlm/image_processing_smolvlm.py +++ b/src/transformers/models/smolvlm/image_processing_smolvlm.py @@ -686,10 +686,7 @@ def preprocess( images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/speech_encoder_decoder/__init__.py b/src/transformers/models/speech_encoder_decoder/__init__.py index 4e07844d45c2..40f66540c963 100644 --- a/src/transformers/models/speech_encoder_decoder/__init__.py +++ b/src/transformers/models/speech_encoder_decoder/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_speech_encoder_decoder import * - from .modeling_flax_speech_encoder_decoder import * from .modeling_speech_encoder_decoder import * else: import sys diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py index ec094769d4ae..b4dce9e2cc61 100644 --- a/src/transformers/models/speech_to_text/__init__.py +++ b/src/transformers/models/speech_to_text/__init__.py @@ -21,7 +21,6 @@ from .configuration_speech_to_text import * from .feature_extraction_speech_to_text import * from .modeling_speech_to_text import * - from .modeling_tf_speech_to_text import * from .processing_speech_to_text import * from .tokenization_speech_to_text import * else: diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py index b6c79c814b97..ead841c4f176 100644 --- a/src/transformers/models/superglue/image_processing_superglue.py +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -289,10 +289,7 @@ def preprocess( images = validate_and_format_image_pairs(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_resize=do_resize, diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index 47f2e2e4b473..dc2c6ab22419 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -245,10 +245,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if do_resize and size is None: raise ValueError("Size must be specified if do_resize is True.") diff --git a/src/transformers/models/swiftformer/__init__.py b/src/transformers/models/swiftformer/__init__.py index 370f6c71fadb..b239996fb976 100644 --- a/src/transformers/models/swiftformer/__init__.py +++ b/src/transformers/models/swiftformer/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_swiftformer import * from .modeling_swiftformer import * - from .modeling_tf_swiftformer import * else: import sys diff --git a/src/transformers/models/swin/__init__.py b/src/transformers/models/swin/__init__.py index 3dc5871b0375..bf351e817fdf 100644 --- a/src/transformers/models/swin/__init__.py +++ b/src/transformers/models/swin/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_swin import * from .modeling_swin import * - from .modeling_tf_swin import * else: import sys diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr.py b/src/transformers/models/swin2sr/image_processing_swin2sr.py index def1c6d533ff..b15e7a9d8f86 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -176,10 +176,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py index 366eab10826b..cdbf8a9937a7 100644 --- a/src/transformers/models/t5/__init__.py +++ b/src/transformers/models/t5/__init__.py @@ -19,9 +19,7 @@ if TYPE_CHECKING: from .configuration_t5 import * - from .modeling_flax_t5 import * from .modeling_t5 import * - from .modeling_tf_t5 import * from .tokenization_t5 import * from .tokenization_t5_fast import * else: diff --git a/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py b/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py deleted file mode 100644 index 12498359d21b..000000000000 --- a/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py +++ /dev/null @@ -1,235 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert T5X checkpoints from the original repository to JAX/FLAX model.""" - -import argparse - -from t5x import checkpoints - -from transformers import FlaxT5ForConditionalGeneration, T5Config - - -def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path): - config = T5Config.from_pretrained(config_name) - flax_model = FlaxT5ForConditionalGeneration(config=config) - t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) - - split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"] - - # Encoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] - t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] - t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] - t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] - - # Layer Normalization - t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] - - if split_mlp_wi: - t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] - t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] - else: - t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] - - t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] - - # Layer Normalization - t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] - - # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - t5x_attention_key - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - t5x_attention_out - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - t5x_attention_query - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - t5x_attention_value - ) - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - t5x_attention_layer_norm - ) - - if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = t5x_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = t5x_mlp_wi_1 - else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( - t5x_mlp_wi - ) - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( - t5x_mlp_wo - ) - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - t5x_mlp_layer_norm - ) - - # Only for layer 0: - t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = t5x_encoder_rel_embedding - - # Assigning - t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm - - # Decoder - for layer_index in range(config.num_decoder_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] - t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] - t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] - t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] - - # Layer Normalization - t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ - "scale" - ] - - # Encoder-Decoder-Attention - t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ - "kernel" - ] - t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ - "kernel" - ] - t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ - "kernel" - ] - t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ - "kernel" - ] - - # Layer Normalization - t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] - - # MLP - if split_mlp_wi: - t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] - t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] - else: - t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] - - t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] - - # Layer Normalization - tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] - - # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( - t5x_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( - t5x_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( - t5x_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( - t5x_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( - t5x_pre_attention_layer_norm - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( - t5x_enc_dec_attention_key - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( - t5x_enc_dec_attention_out - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( - t5x_enc_dec_attention_query - ) - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( - t5x_enc_dec_attention_value - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( - t5x_cross_layer_norm - ) - - if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = t5x_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = t5x_mlp_wi_1 - else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( - t5x_mlp_wi - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( - t5x_mlp_wo - ) - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( - tx5_mlp_layer_norm - ) - - # Decoder Normalization - tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"] - flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm - - # Only for layer 0: - t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = t5x_decoder_rel_embedding - - # Token Embeddings - tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"] - flax_model.params["shared"]["embedding"] = tx5_token_embeddings - - # LM Head (only in v1.1 checkpoints) - if "logits_dense" in t5x_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"] - - flax_model.save_pretrained(flax_dump_folder_path) - print("T5X Model was successfully converted!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." - ) - parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of T5 model.") - parser.add_argument( - "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." - ) - args = parser.parse_args() - convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/tapas/__init__.py b/src/transformers/models/tapas/__init__.py index 7df7e765f60e..d85329883381 100644 --- a/src/transformers/models/tapas/__init__.py +++ b/src/transformers/models/tapas/__init__.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from .configuration_tapas import * from .modeling_tapas import * - from .modeling_tf_tapas import * from .tokenization_tapas import * else: import sys diff --git a/src/transformers/models/textnet/image_processing_textnet.py b/src/transformers/models/textnet/image_processing_textnet.py index 7a461d7145c1..578dabd3cb71 100644 --- a/src/transformers/models/textnet/image_processing_textnet.py +++ b/src/transformers/models/textnet/image_processing_textnet.py @@ -291,10 +291,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index 9ef3c93fb13f..12356c082a03 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -435,10 +435,7 @@ def preprocess( crop_size = get_size_dict(crop_size, param_name="crop_size") if not valid_images(videos): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") videos = make_batched(videos) diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 0e9141a70dfb..d02ceff80c1e 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -259,10 +259,7 @@ def preprocess( images = make_flat_list_of_images(images) if images is not None and not valid_images(images): - raise ValueError( - "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") data = {} if videos is not None: diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py index 622ae7ab9c47..b1e3ffd4de91 100644 --- a/src/transformers/models/videomae/image_processing_videomae.py +++ b/src/transformers/models/videomae/image_processing_videomae.py @@ -312,10 +312,7 @@ def preprocess( crop_size = get_size_dict(crop_size, param_name="crop_size") if not valid_images(videos): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") videos = make_batched(videos) diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index 28f33cf9017a..c7013e660332 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -416,10 +416,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") # Here the pad() method does not require any additional argument as it takes the maximum of (height, width). # Hence, it does not need to be passed to a validate_preprocess_arguments() method. diff --git a/src/transformers/models/vision_encoder_decoder/__init__.py b/src/transformers/models/vision_encoder_decoder/__init__.py index 613aae114b33..c2afd574c8df 100644 --- a/src/transformers/models/vision_encoder_decoder/__init__.py +++ b/src/transformers/models/vision_encoder_decoder/__init__.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from .configuration_vision_encoder_decoder import * - from .modeling_flax_vision_encoder_decoder import * - from .modeling_tf_vision_encoder_decoder import * from .modeling_vision_encoder_decoder import * else: import sys diff --git a/src/transformers/models/vision_text_dual_encoder/__init__.py b/src/transformers/models/vision_text_dual_encoder/__init__.py index 4b68df9c336f..8043c28bcca8 100644 --- a/src/transformers/models/vision_text_dual_encoder/__init__.py +++ b/src/transformers/models/vision_text_dual_encoder/__init__.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from .configuration_vision_text_dual_encoder import * - from .modeling_flax_vision_text_dual_encoder import * - from .modeling_tf_vision_text_dual_encoder import * from .modeling_vision_text_dual_encoder import * from .processing_vision_text_dual_encoder import * else: diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index 039f9fa9e9c5..50d0c433cfce 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -299,10 +299,8 @@ def from_vision_text_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. This - loading path is slower than converting the PyTorch checkpoint in a Flax model using the provided - conversion scripts and loading the Flax model afterwards. + - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, a configuration + object should be provided as `config` argument. text_model_name_or_path (`str`, *optional*): Information necessary to initiate the text model. Can be either: @@ -310,10 +308,8 @@ def from_vision_text_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, `from_pt` - should be set to `True` and a configuration object should be provided as `config` argument. This - loading path is slower than converting the PyTorch checkpoint in a Flax model using the provided - conversion scripts and loading the Flax model afterwards. + - A path or url to a *PyTorch checkpoint folder* (e.g, `./pt_model`). In this case, a configuration + object should be provided as `config` argument. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py index 4d6a7a23fa63..fcb1027b0d6d 100644 --- a/src/transformers/models/vit/__init__.py +++ b/src/transformers/models/vit/__init__.py @@ -22,8 +22,6 @@ from .feature_extraction_vit import * from .image_processing_vit import * from .image_processing_vit_fast import * - from .modeling_flax_vit import * - from .modeling_tf_vit import * from .modeling_vit import * else: import sys diff --git a/src/transformers/models/vit/image_processing_vit.py b/src/transformers/models/vit/image_processing_vit.py index b7e118007ace..645e2616b2ee 100644 --- a/src/transformers/models/vit/image_processing_vit.py +++ b/src/transformers/models/vit/image_processing_vit.py @@ -226,10 +226,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/vit_mae/__init__.py b/src/transformers/models/vit_mae/__init__.py index 253017c39d6a..c64dd050a60b 100644 --- a/src/transformers/models/vit_mae/__init__.py +++ b/src/transformers/models/vit_mae/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_vit_mae import * - from .modeling_tf_vit_mae import * from .modeling_vit_mae import * else: import sys diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py index 6a06df90a30f..87b6d2662ef4 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -214,16 +214,10 @@ def preprocess( trimaps = make_flat_list_of_images(trimaps, expected_ndims=2) if not valid_images(trimaps): - raise ValueError( - "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/models/vitpose/image_processing_vitpose.py b/src/transformers/models/vitpose/image_processing_vitpose.py index 1fffc0e389c5..c4a10d35944b 100644 --- a/src/transformers/models/vitpose/image_processing_vitpose.py +++ b/src/transformers/models/vitpose/image_processing_vitpose.py @@ -485,10 +485,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") if isinstance(boxes, list) and len(images) != len(boxes): raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {len(boxes)}") diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index 51655a71cf51..e287e1d608a2 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -368,10 +368,7 @@ def preprocess( crop_size = get_size_dict(crop_size, param_name="crop_size") if not valid_images(videos): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") videos = make_batched(videos) diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index 3516b478194d..aa3a5c4c82f8 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_wav2vec2 import * from .feature_extraction_wav2vec2 import * - from .modeling_flax_wav2vec2 import * - from .modeling_tf_wav2vec2 import * from .modeling_wav2vec2 import * from .processing_wav2vec2 import * from .tokenization_wav2vec2 import * diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index a4956c5fbb25..ec73ac2b8fe9 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -20,8 +20,6 @@ if TYPE_CHECKING: from .configuration_whisper import * from .feature_extraction_whisper import * - from .modeling_flax_whisper import * - from .modeling_tf_whisper import * from .modeling_whisper import * from .processing_whisper import * from .tokenization_whisper import * diff --git a/src/transformers/models/xglm/__init__.py b/src/transformers/models/xglm/__init__.py index 1eefd79d4cf7..363babae7e6c 100644 --- a/src/transformers/models/xglm/__init__.py +++ b/src/transformers/models/xglm/__init__.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from .configuration_xglm import * - from .modeling_flax_xglm import * - from .modeling_tf_xglm import * from .modeling_xglm import * from .tokenization_xglm import * from .tokenization_xglm_fast import * diff --git a/src/transformers/models/xglm/configuration_xglm.py b/src/transformers/models/xglm/configuration_xglm.py index d8a3be370b7f..eae648c4726a 100644 --- a/src/transformers/models/xglm/configuration_xglm.py +++ b/src/transformers/models/xglm/configuration_xglm.py @@ -35,7 +35,7 @@ class XGLMConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 256008): Vocabulary size of the XGLM model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`XGLMModel`] or [`FlaxXGLMModel`]. + `inputs_ids` passed when calling [`XGLMModel`]. max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). diff --git a/src/transformers/models/xlm/__init__.py b/src/transformers/models/xlm/__init__.py index 1167fc93a101..d6ad3ff9c90d 100644 --- a/src/transformers/models/xlm/__init__.py +++ b/src/transformers/models/xlm/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_xlm import * - from .modeling_tf_xlm import * from .modeling_xlm import * from .tokenization_xlm import * else: diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py index 0e684c6c9b2c..1706e6dbefae 100644 --- a/src/transformers/models/xlm_roberta/__init__.py +++ b/src/transformers/models/xlm_roberta/__init__.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from .configuration_xlm_roberta import * - from .modeling_flax_xlm_roberta import * - from .modeling_tf_xlm_roberta import * from .modeling_xlm_roberta import * from .tokenization_xlm_roberta import * from .tokenization_xlm_roberta_fast import * diff --git a/src/transformers/models/xlnet/__init__.py b/src/transformers/models/xlnet/__init__.py index 3f4534559253..73fe8d46985c 100644 --- a/src/transformers/models/xlnet/__init__.py +++ b/src/transformers/models/xlnet/__init__.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from .configuration_xlnet import * - from .modeling_tf_xlnet import * from .modeling_xlnet import * from .tokenization_xlnet import * from .tokenization_xlnet_fast import * diff --git a/src/transformers/models/zoedepth/image_processing_zoedepth.py b/src/transformers/models/zoedepth/image_processing_zoedepth.py index 85da72325232..1ef2b8a59ec1 100644 --- a/src/transformers/models/zoedepth/image_processing_zoedepth.py +++ b/src/transformers/models/zoedepth/image_processing_zoedepth.py @@ -386,10 +386,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index e9ef235af087..2f64aa416309 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -230,6 +230,7 @@ def load_model( """ if not is_torch_available(): raise RuntimeError("PyTorch should be installed. Please follow the instructions at https://pytorch.org/.") + if isinstance(model, str): model_kwargs["_from_pipeline"] = task class_tuple = model_classes if model_classes is not None else (AutoModel,) @@ -259,7 +260,7 @@ def load_model( # we can transparently retry the load in float32 before surfacing an error to the user. fallback_tried = False if "dtype" in kwargs: - import torch # local import to avoid unnecessarily importing torch for TF/JAX users + import torch fallback_tried = True fp32_kwargs = kwargs.copy() diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ee1bb3f4f4ca..70eed29f3a65 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -275,11 +275,6 @@ WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" -TF2_WEIGHTS_NAME = "tf_model.h5" -TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" -TF_WEIGHTS_NAME = "model.ckpt" -FLAX_WEIGHTS_NAME = "flax_model.msgpack" -FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" SAFE_WEIGHTS_NAME = "model.safetensors" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" CONFIG_NAME = "config.json" diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py index 436cba19c290..8911d39ec10c 100644 --- a/tests/models/aya_vision/test_modeling_aya_vision.py +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -198,7 +198,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index b352b8160468..0217c5914300 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -754,7 +754,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="CLIP uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="CLIP uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/cohere2_vision/test_modeling_cohere2_vision.py b/tests/models/cohere2_vision/test_modeling_cohere2_vision.py index 96843faa95f7..7a12c2ad9fca 100644 --- a/tests/models/cohere2_vision/test_modeling_cohere2_vision.py +++ b/tests/models/cohere2_vision/test_modeling_cohere2_vision.py @@ -170,7 +170,7 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - @unittest.skip(reason="Siglip backbone uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip backbone uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py index 7966e34ce323..e051e431bfa8 100644 --- a/tests/models/colpali/test_modeling_colpali.py +++ b/tests/models/colpali/test_modeling_colpali.py @@ -272,9 +272,7 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_model_parallelism(self): pass - @unittest.skip( - reason="PaliGemma's SigLip encoder uses the same initialization scheme as the Flax original implementation" - ) + @unittest.skip(reason="PaliGemma's SigLip encoder uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/deepseek_vl/test_modeling_deepseek_vl.py b/tests/models/deepseek_vl/test_modeling_deepseek_vl.py index a2d1950dcdc4..55ced08a09d3 100644 --- a/tests/models/deepseek_vl/test_modeling_deepseek_vl.py +++ b/tests/models/deepseek_vl/test_modeling_deepseek_vl.py @@ -187,7 +187,7 @@ def test_inputs_embeds_matches_input_ids(self): out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") # Copied from tests.models.siglip.test_modeling_siglip.SiglipVisionModelTest.test_initialization def test_initialization(self): pass diff --git a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py index fbb904da735b..02a934275012 100644 --- a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py +++ b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py @@ -218,7 +218,7 @@ def test_inputs_embeds_matches_input_ids(self): out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") # Copied from tests.models.siglip.test_modeling_siglip.SiglipVisionModelTest.test_initialization def test_initialization(self): pass diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 5e4b774a8bd0..eca8cfdc56ee 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -714,9 +714,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip( - reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" - ) + @unittest.skip(reason="Siglip (vision backbone) uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/metaclip_2/test_modeling_metaclip_2.py b/tests/models/metaclip_2/test_modeling_metaclip_2.py index f8ad7701eab3..19823ba4ac73 100644 --- a/tests/models/metaclip_2/test_modeling_metaclip_2.py +++ b/tests/models/metaclip_2/test_modeling_metaclip_2.py @@ -765,7 +765,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="MetaClip2 uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="MetaClip2 uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index d130122b16ff..21b9b8a4711e 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -264,9 +264,7 @@ def test_disk_offload_safetensors(self): def test_model_parallelism(self): pass - @unittest.skip( - reason="PaliGemma's SigLip encoder uses the same initialization scheme as the Flax original implementation" - ) + @unittest.skip(reason="PaliGemma's SigLip encoder uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index ad345e70e03e..a33f03194f8a 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -247,9 +247,7 @@ def test_disk_offload_safetensors(self): def test_model_parallelism(self): pass - @unittest.skip( - reason="PaliGemma's SigLip encoder uses the same initialization scheme as the Flax original implementation" - ) + @unittest.skip(reason="PaliGemma's SigLip encoder uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index a4c829493b17..0005c44e634a 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -240,7 +240,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") def test_initialization(self): pass @@ -386,7 +386,7 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") def test_initialization(self): pass @@ -498,7 +498,7 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") def test_initialization(self): pass @@ -658,7 +658,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py index e7147e6055aa..d6054dd8d15d 100644 --- a/tests/models/siglip2/test_modeling_siglip2.py +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -332,7 +332,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip2 uses a non-standard initialization scheme") def test_initialization(self): pass @@ -474,7 +474,7 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip2 uses a non-standard initialization scheme") def test_initialization(self): pass @@ -591,7 +591,7 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip2 uses a non-standard initialization scheme") def test_initialization(self): pass @@ -689,7 +689,7 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + @unittest.skip(reason="Siglip2 uses a non-standard initialization scheme") def test_initialization(self): pass diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py index b86773793a84..df3af3d97a83 100644 --- a/tests/utils/test_hub_utils.py +++ b/tests/utils/test_hub_utils.py @@ -24,8 +24,6 @@ from transformers.utils import ( CONFIG_NAME, - FLAX_WEIGHTS_NAME, - TF2_WEIGHTS_NAME, TRANSFORMERS_CACHE, WEIGHTS_NAME, cached_file, @@ -97,8 +95,8 @@ def test_non_existence_is_cached(self): def test_has_file(self): self.assertTrue(has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME)) - self.assertFalse(has_file(TINY_BERT_PT_ONLY, TF2_WEIGHTS_NAME)) - self.assertFalse(has_file(TINY_BERT_PT_ONLY, FLAX_WEIGHTS_NAME)) + self.assertFalse(has_file(TINY_BERT_PT_ONLY, "tf_model.h5")) + self.assertFalse(has_file(TINY_BERT_PT_ONLY, "flax_model.msgpack")) def test_has_file_in_cache(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/utils/add_pipeline_model_mapping_to_test.py b/utils/add_pipeline_model_mapping_to_test.py index 9e261da03d14..37723bf0bb9c 100644 --- a/utils/add_pipeline_model_mapping_to_test.py +++ b/utils/add_pipeline_model_mapping_to_test.py @@ -117,10 +117,9 @@ def get_pipeline_model_mapping_string(test_class): def is_valid_test_class(test_class): """Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`.""" - base_class_names = {"ModelTesterMixin", "TFModelTesterMixin", "FlaxModelTesterMixin"} if not issubclass(test_class, unittest.TestCase): return False - return len(base_class_names.intersection([x.__name__ for x in test_class.__bases__])) > 0 + return "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__] def find_test_class(test_file): @@ -300,9 +299,7 @@ def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False): else: pattern = os.path.join("tests", "models", "**", "test_modeling_*.py") for test_file in glob.glob(pattern): - # `Flax` is not concerned at this moment - if not test_file.startswith("test_modeling_flax_"): - test_files.append(test_file) + test_files.append(test_file) for test_file in test_files: if test_file in TEST_FILE_TO_IGNORE: diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 9eeda74afa48..754a86941d93 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -465,8 +465,6 @@ # below, make sure to add a comment explaining why. OBJECT_TO_IGNORE_PREFIXES = [ "_", # Private objects are not documented - "TF", # TensorFlow objects are scheduled to be removed in the future - "Flax", # Flax objects are scheduled to be removed in the future ] # Supported math operations when interpreting the value of defaults. @@ -923,14 +921,10 @@ def find_matching_model_files(check_all: bool = False): potential_files = glob.glob(modeling_glob_pattern) image_processing_glob_pattern = os.path.join(PATH_TO_TRANSFORMERS, "models/**/image_processing_*_fast.py") potential_files += glob.glob(image_processing_glob_pattern) - exclude_substrings = ["modeling_tf_", "modeling_flax_"] matching_files = [] for file_path in potential_files: if os.path.isfile(file_path): - filename = os.path.basename(file_path) - is_excluded = any(exclude in filename for exclude in exclude_substrings) - if not is_excluded: - matching_files.append(file_path) + matching_files.append(file_path) if not check_all: # intersect with module_diff_files matching_files = sorted([file for file in matching_files if file in module_diff_files]) diff --git a/utils/check_inits.py b/utils/check_inits.py index bc211baa9edc..e90990ac607a 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -308,7 +308,6 @@ def get_transformers_submodules() -> list[str]: IGNORE_SUBMODULES = [ "convert_pytorch_checkpoint_to_tf2", - "modeling_flax_pytorch_utils", "models.esm.openfold_utils", "modeling_attn_mask_utils", "safetensors_conversion", diff --git a/utils/check_model_tester.py b/utils/check_model_tester.py index 8ace411b1a4e..60f7c0d32c3d 100644 --- a/utils/check_model_tester.py +++ b/utils/check_model_tester.py @@ -25,10 +25,6 @@ pattern = os.path.join("tests", "models", "**", "test_modeling_*.py") test_files = glob.glob(pattern) - # TODO: deal with TF/Flax too - test_files = [ - x for x in test_files if not (x.startswith("test_modeling_tf_") or x.startswith("test_modeling_flax_")) - ] for test_file in test_files: tester_classes = get_tester_classes(test_file) diff --git a/utils/get_test_info.py b/utils/get_test_info.py index 3c376bdbdaaf..d3be2792c3d9 100644 --- a/utils/get_test_info.py +++ b/utils/get_test_info.py @@ -81,15 +81,14 @@ def get_tester_classes(test_file): def get_test_classes(test_file): """Get all [test] classes in a model test file with attribute `all_model_classes` that are non-empty. - These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of one of the - classes `ModelTesterMixin`, `TFModelTesterMixin` or `FlaxModelTesterMixin`, as well as a subclass of - `unittest.TestCase`. Exceptions include `RagTestMixin` (and its subclasses). + These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of + `ModelTesterMixin`, as well as a subclass of `unittest.TestCase`. Exceptions include `RagTestMixin` (and its subclasses). """ test_classes = [] test_module = get_test_module(test_file) for attr in dir(test_module): attr_value = getattr(test_module, attr) - # (TF/Flax)ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking + # ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking # `all_model_classes` is not empty (which also excludes other special classes). model_classes = getattr(attr_value, "all_model_classes", []) if len(model_classes) > 0: @@ -118,7 +117,7 @@ def get_model_tester_from_test_class(test_class): model_tester = None if hasattr(test, "model_tester"): - # `(TF/Flax)ModelTesterMixin` has this attribute default to `None`. Let's skip this case. + # `ModelTesterMixin` has this attribute default to `None`. Let's skip this case. if test.model_tester is not None: model_tester = test.model_tester.__class__ diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py index 17ea1fd28ec8..a92e1019cd1a 100644 --- a/utils/models_to_deprecate.py +++ b/utils/models_to_deprecate.py @@ -61,10 +61,6 @@ def get_list_of_repo_model_paths(models_dir): # Get list of all models in the library models = glob.glob(os.path.join(models_dir, "*/modeling_*.py")) - # Remove flax and tf models - models = [model for model in models if "_flax_" not in model] - models = [model for model in models if "_tf_" not in model] - # Get list of all deprecated models in the library deprecated_models = glob.glob(os.path.join(models_dir, "deprecated", "*")) # For each deprecated model, remove the deprecated models from the list of all models as well as the symlink path diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 942313d57bb2..f2021bbca0ba 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -536,7 +536,6 @@ src/transformers/models/llava_next/modeling_llava_next.py src/transformers/models/longformer/configuration_longformer.py src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py src/transformers/models/longt5/configuration_longt5.py -src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py src/transformers/models/luke/configuration_luke.py src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py src/transformers/models/luke/modeling_luke.py @@ -677,7 +676,6 @@ src/transformers/models/switch_transformers/configuration_switch_transformers.py src/transformers/models/switch_transformers/convert_big_switch.py src/transformers/models/switch_transformers/modeling_switch_transformers.py src/transformers/models/t5/configuration_t5.py -src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py src/transformers/models/t5/modeling_t5.py src/transformers/models/table_transformer/configuration_table_transformer.py diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index f0ecb7842cd1..49821d703890 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -1092,15 +1092,15 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]: JOB_TO_TEST_FILE = { - "tests_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", - "tests_generate": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", + "tests_torch": r"tests/models/.*/test_modeling_.*", + "tests_generate": r"tests/models/.*/test_modeling_.*", "tests_tokenization": r"tests/(?:models/.*/test_tokenization.*|test_tokenization_mistral_common\.py)", "tests_processors": r"tests/models/.*/test_(?!(?:modeling_|tokenization_)).*", # takes feature extractors, image processors, processors "examples_torch": r"examples/pytorch/.*test_.*", "tests_exotic_models": r"tests/models/.*(?=layoutlmv|nat|deta|udop|nougat).*", "tests_custom_tokenizers": r"tests/models/.*/test_tokenization_(?=bert_japanese|openai|clip).*", # "repo_utils": r"tests/[^models].*test.*", TODO later on we might want to do - "pipelines_torch": r"tests/models/.*/test_modeling_(?!(?:flax_|tf_)).*", + "pipelines_torch": r"tests/models/.*/test_modeling_.*", "tests_hub": r"tests/.*", "tests_non_model": r"tests/[^/]*?/test_.*\.py", } diff --git a/utils/update_tiny_models.py b/utils/update_tiny_models.py index c7d62f8e94b8..d5cb048ea639 100644 --- a/utils/update_tiny_models.py +++ b/utils/update_tiny_models.py @@ -36,18 +36,12 @@ def get_all_model_names(): model_names = set() - # Each auto modeling files contains multiple mappings. Let's get them in a dynamic way. - for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]: - module = getattr(transformers.models.auto, module_name, None) - if module is None: - continue + + module_name = "modeling_auto" + module = getattr(transformers.models.auto, module_name, None) + if module is not None: # all mappings in a single auto modeling file - mapping_names = [ - x - for x in dir(module) - if x.endswith("_MAPPING_NAMES") - and (x.startswith("MODEL_") or x.startswith("TF_MODEL_") or x.startswith("FLAX_MODEL_")) - ] + mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES") and x.startswith("MODEL_")] for name in mapping_names: mapping = getattr(module, name) if mapping is not None: From dd22eeb763701d248c46e189927c12fd2a528933 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 00:31:11 +0200 Subject: [PATCH 15/35] should fix doc --- docs/source/ar/_toctree.yml | 2 -- docs/source/en/_toctree.yml | 2 -- .../source/en/main_classes/keras_callbacks.md | 28 ------------------- docs/source/ja/_toctree.yml | 2 -- .../source/ja/main_classes/keras_callbacks.md | 28 ------------------- docs/source/ko/_toctree.yml | 2 -- .../source/ko/main_classes/keras_callbacks.md | 27 ------------------ docs/source/ms/_toctree.yml | 2 -- docs/source/zh/_toctree.yml | 2 -- .../source/zh/main_classes/keras_callbacks.md | 27 ------------------ utils/not_doctested.txt | 1 - 11 files changed, 123 deletions(-) delete mode 100644 docs/source/en/main_classes/keras_callbacks.md delete mode 100644 docs/source/ja/main_classes/keras_callbacks.md delete mode 100644 docs/source/ko/main_classes/keras_callbacks.md delete mode 100644 docs/source/zh/main_classes/keras_callbacks.md diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml index 8cd45939b3b6..2ac585afadfa 100644 --- a/docs/source/ar/_toctree.yml +++ b/docs/source/ar/_toctree.yml @@ -254,8 +254,6 @@ # title: التكوين # - local: main_classes/data_collator # title: مجمع البيانات -# - local: main_classes/keras_callbacks -# title: استدعاءات Keras # - local: main_classes/logging # title: التسجيل # - local: main_classes/model diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0a64dc03510f..61fea5a26ae7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -334,8 +334,6 @@ title: Configuration - local: main_classes/data_collator title: Data Collator - - local: main_classes/keras_callbacks - title: Keras callbacks - local: main_classes/logging title: Logging - local: main_classes/model diff --git a/docs/source/en/main_classes/keras_callbacks.md b/docs/source/en/main_classes/keras_callbacks.md deleted file mode 100644 index c9932300dbc5..000000000000 --- a/docs/source/en/main_classes/keras_callbacks.md +++ /dev/null @@ -1,28 +0,0 @@ - - -# Keras callbacks - -When training a Transformers model with Keras, there are some library-specific callbacks available to automate common -tasks: - -## KerasMetricCallback - -[[autodoc]] KerasMetricCallback - -## PushToHubCallback - -[[autodoc]] PushToHubCallback diff --git a/docs/source/ja/_toctree.yml b/docs/source/ja/_toctree.yml index 3f74f498fbcf..d01cf584ecff 100644 --- a/docs/source/ja/_toctree.yml +++ b/docs/source/ja/_toctree.yml @@ -196,8 +196,6 @@ title: 構成 - local: main_classes/data_collator title: データ照合者 - - local: main_classes/keras_callbacks - title: Keras コールバック - local: main_classes/logging title: ロギング - local: main_classes/model diff --git a/docs/source/ja/main_classes/keras_callbacks.md b/docs/source/ja/main_classes/keras_callbacks.md deleted file mode 100644 index ff28107a4345..000000000000 --- a/docs/source/ja/main_classes/keras_callbacks.md +++ /dev/null @@ -1,28 +0,0 @@ - - -# Keras callbacks - -Keras を使用して Transformers モデルをトレーニングする場合、一般的な処理を自動化するために使用できるライブラリ固有のコールバックがいくつかあります。 -タスク: - -## KerasMetricCallback - -[[autodoc]] KerasMetricCallback - -## PushToHubCallback - -[[autodoc]] PushToHubCallback diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index afc0bcf4fa35..21f26cd66af6 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -400,8 +400,6 @@ title: Configuration - local: main_classes/data_collator title: Data Collator - - local: main_classes/keras_callbacks - title: Keras callbacks - local: main_classes/logging title: Logging - local: main_classes/model diff --git a/docs/source/ko/main_classes/keras_callbacks.md b/docs/source/ko/main_classes/keras_callbacks.md deleted file mode 100644 index 25d5ea3e4008..000000000000 --- a/docs/source/ko/main_classes/keras_callbacks.md +++ /dev/null @@ -1,27 +0,0 @@ - - -# 케라스 콜백[[keras-callbacks]] - -케라스로 트랜스포머 모델을 학습할 때, 일반적인 작업을 자동화하기 위한 라이브러리 전용 콜백들을 사용 할 수 있습니다. - -## KerasMetricCallback[[transformers.KerasMetricCallback]] - -[[autodoc]] KerasMetricCallback - -## PushToHubCallback[[transformers.PushToHubCallback]] - -[[autodoc]] PushToHubCallback diff --git a/docs/source/ms/_toctree.yml b/docs/source/ms/_toctree.yml index f57a5bab78e9..05d4829437b9 100644 --- a/docs/source/ms/_toctree.yml +++ b/docs/source/ms/_toctree.yml @@ -181,8 +181,6 @@ title: Configuration - local: main_classes/data_collator title: Data Collator - - local: main_classes/keras_callbacks - title: Keras callbacks - local: main_classes/logging title: Logging - local: main_classes/model diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 0773829e9a6a..ad7e2479b42e 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -103,8 +103,6 @@ title: Configuration - local: main_classes/data_collator title: Data Collator - - local: main_classes/keras_callbacks - title: Keras callbacks - local: main_classes/logging title: Logging - local: main_classes/model diff --git a/docs/source/zh/main_classes/keras_callbacks.md b/docs/source/zh/main_classes/keras_callbacks.md deleted file mode 100644 index 1eea2eb99816..000000000000 --- a/docs/source/zh/main_classes/keras_callbacks.md +++ /dev/null @@ -1,27 +0,0 @@ - - -# Keras callbacks - -在Keras中训练Transformers模型时,有一些库特定的callbacks函数可用于自动执行常见任务: - -## KerasMetricCallback - -[[autodoc]] KerasMetricCallback - -## PushToHubCallback - -[[autodoc]] PushToHubCallback diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index f2021bbca0ba..67015ac0c90c 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -26,7 +26,6 @@ docs/source/en/main_classes/data_collator.md docs/source/en/main_classes/deepspeed.md docs/source/en/main_classes/feature_extractor.md docs/source/en/main_classes/image_processor.md -docs/source/en/main_classes/keras_callbacks.md docs/source/en/main_classes/logging.md docs/source/en/main_classes/model.md docs/source/en/main_classes/onnx.md From bd0c169d0973acc8439373e1f97d63a42082d2c8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 00:35:32 +0200 Subject: [PATCH 16/35] fic --- src/transformers/onnx/features.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 5a2180798c22..1c57c68e8c87 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -578,7 +578,7 @@ def get_model_class_for_feature(feature: str) -> type: return task_to_automodel[task] @staticmethod - def get_model_from_feature(feature: str, model: str, cache_dir: Optional[str] = None) -> PreTrainedModel: + def get_model_from_feature(feature: str, model: str, cache_dir: Optional[str] = None) -> "PreTrainedModel": """ Attempts to retrieve a model from a model's name and the feature to be enabled. @@ -597,7 +597,7 @@ def get_model_from_feature(feature: str, model: str, cache_dir: Optional[str] = return model @staticmethod - def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> tuple[str, Callable]: + def check_supported_model_or_raise(model: "PreTrainedModel", feature: str = "default") -> tuple[str, Callable]: """ Check whether or not the model has the requested features. From 2b9588dda66dd462badc5240123f637773e498ee Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 12:54:41 +0200 Subject: [PATCH 17/35] fix --- src/transformers/data/data_collator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index c9cdbdce97f4..1bff72cf338c 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -40,7 +40,7 @@ class DataCollatorMixin: def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors - elif return_tensors == "pt": + if return_tensors == "pt": return self.torch_call(features) elif return_tensors == "np": return self.numpy_call(features) From d58b6794cde3aaf5ac470b944c61a9fcd274fd88 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 13:19:11 +0200 Subject: [PATCH 18/35] fix --- src/transformers/modeling_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8a0270567a29..31783d041fe4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -990,7 +990,7 @@ def _get_resolved_checkpoint_files( subfolder: str, variant: Optional[str], gguf_file: Optional[str], - use_safetensors: bool, + use_safetensors: Optional[bool], cache_dir: str, force_download: bool, proxies: Optional[dict[str, str]], @@ -1016,14 +1016,14 @@ def _get_resolved_checkpoint_files( # If the filename is explicitly defined, load this by default. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename) is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json") - elif use_safetensors and os.path.isfile( + elif use_safetensors is not False and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) ): # Load from a safetensors checkpoint archive_file = os.path.join( pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) ) - elif use_safetensors and os.path.isfile( + elif use_safetensors is not False and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) ): # Load from a sharded safetensors checkpoint @@ -1067,7 +1067,7 @@ def _get_resolved_checkpoint_files( if transformers_explicit_filename is not None: filename = transformers_explicit_filename is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json") - elif use_safetensors: + elif use_safetensors is not False: filename = _add_variant(SAFE_WEIGHTS_NAME, variant) else: filename = _add_variant(WEIGHTS_NAME, variant) @@ -4420,7 +4420,7 @@ def from_pretrained( local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", - use_safetensors: bool = True, + use_safetensors: Optional[bool] = None, weights_only: bool = True, **kwargs, ) -> SpecificPreTrainedModelType: @@ -4588,9 +4588,9 @@ def from_pretrained( specify the folder name here. variant (`str`, *optional*): If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. - use_safetensors (`bool`, *optional*, defaults to `True`): - Whether or not to use `safetensors` checkpoints. Defaults to `True`. If `safetensors` is not installed, - it will be set to `False`. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. weights_only (`bool`, *optional*, defaults to `True`): Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). @@ -4741,7 +4741,7 @@ def from_pretrained( if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: adapter_kwargs["token"] = token - if use_safetensors and not is_safetensors_available(): + if use_safetensors is None and not is_safetensors_available(): use_safetensors = False if gguf_file is not None and not is_accelerate_available(): From 15e9e8da4fb0fdf3550110f9bca2aa57a3bc24dd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 14:16:49 +0200 Subject: [PATCH 19/35] fix tests --- src/transformers/pipelines/base.py | 3 +-- src/transformers/utils/generic.py | 2 +- tests/utils/test_import_structure.py | 4 +--- tests/utils/test_modeling_utils.py | 4 ++-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 2f64aa416309..71450199580c 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -61,7 +61,6 @@ from torch.utils.data import DataLoader, Dataset from ..modeling_utils import PreTrainedModel - from ..models.auto.modeling_auto import AutoModel # Re-export for backward compatibility from .pt_utils import KeyDataset @@ -233,7 +232,7 @@ def load_model( if isinstance(model, str): model_kwargs["_from_pipeline"] = task - class_tuple = model_classes if model_classes is not None else (AutoModel,) + class_tuple = model_classes if model_classes is not None else () if config.architectures: classes = [] for architecture in config.architectures: diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index cc2f8b5f7046..ef5e356bcd1c 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -544,7 +544,7 @@ def squeeze(array, axis=None): if is_numpy_array(array): return np.squeeze(array, axis=axis) elif is_torch_tensor(array): - return array.squeeze(dim=axis) + return array.squeeze() if axis is None else array.squeeze(dim=axis) else: raise ValueError(f"Type not supported for squeeze: {type(array)}.") diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index d69f4f0df1e4..6c293b161e8a 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -52,11 +52,9 @@ def test_definition(self): "import_structure_register_with_comments": {"B0", "b0"}, }, frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}}, - frozenset({"torch"}): { - "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"} - }, frozenset({"torch"}): { "import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"}, + "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"}, "import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"}, }, frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}}, diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index be55cc563300..bf6889338b0e 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1290,7 +1290,7 @@ def test_use_safetensors(self): BertModel.from_pretrained("hf-internal-testing/config-no-model") self.assertTrue( - "does not appear to have a file named pytorch_model.bin, model.safetensors," + "does not appear to have a file named pytorch_model.bin or model.safetensors." in str(missing_model_file_error.exception) ) @@ -1302,7 +1302,7 @@ def test_use_safetensors(self): BertModel.from_pretrained(tmp_dir) self.assertTrue( - "Error no file named pytorch_model.bin, model.safetensors" in str(missing_model_file_error.exception) + "Error no file named model.safetensors, or pytorch_model.bin" in str(missing_model_file_error.exception) ) @require_safetensors From de27613f8ade5576b7f4f0d79698366bb34c1b1f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 14:19:36 +0200 Subject: [PATCH 20/35] still tests --- tests/models/auto/test_modeling_auto.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 9d6e9569a9dc..352df1fe7b58 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -515,11 +515,15 @@ def test_model_file_not_found(self): _ = AutoModel.from_pretrained("hf-internal-testing/config-no-model") def test_model_from_tf_error(self): - with self.assertRaisesRegex(EnvironmentError, "Can't load the model for"): + with self.assertRaisesRegex( + EnvironmentError, "does not appear to have a file named pytorch_model.bin or model.safetensors." + ): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") def test_model_from_flax_error(self): - with self.assertRaisesRegex(EnvironmentError, "Can't load the model for"): + with self.assertRaisesRegex( + EnvironmentError, "does not appear to have a file named pytorch_model.bin or model.safetensors." + ): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") @unittest.skip("Failing on main") From 6e434e03649377acab9a9823fca5b2fa19215917 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 14:52:50 +0200 Subject: [PATCH 21/35] fix non-deterministic --- tests/utils/test_import_structure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 6c293b161e8a..9c3222d70c50 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -71,7 +71,7 @@ def test_definition(self): self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys())) for _frozenset in valid_frozensets: self.assertTrue(_frozenset in import_structure) - self.assertListEqual(list(import_structure[_frozenset].keys()), list(valid_frozensets[_frozenset].keys())) + self.assertListEqual(sorted(import_structure[_frozenset].keys()), sorted(valid_frozensets[_frozenset].keys())) for module, objects in valid_frozensets[_frozenset].items(): self.assertTrue(module in import_structure[_frozenset]) self.assertSetEqual(objects, import_structure[_frozenset][module]) From dc65cae15ed058e2afb5a0314a2067f6057914f4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 14:57:43 +0200 Subject: [PATCH 22/35] style --- tests/utils/test_import_structure.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py index 9c3222d70c50..6d40cedaea14 100644 --- a/tests/utils/test_import_structure.py +++ b/tests/utils/test_import_structure.py @@ -71,7 +71,9 @@ def test_definition(self): self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys())) for _frozenset in valid_frozensets: self.assertTrue(_frozenset in import_structure) - self.assertListEqual(sorted(import_structure[_frozenset].keys()), sorted(valid_frozensets[_frozenset].keys())) + self.assertListEqual( + sorted(import_structure[_frozenset].keys()), sorted(valid_frozensets[_frozenset].keys()) + ) for module, objects in valid_frozensets[_frozenset].items(): self.assertTrue(module in import_structure[_frozenset]) self.assertSetEqual(objects, import_structure[_frozenset][module]) From 6c231c6826bacc337da6d016c48c6409013d1815 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 15:27:28 +0200 Subject: [PATCH 23/35] remove last rebase issues --- tests/models/tapas/test_tokenization_tapas.py | 16 ++-------------- tests/test_tokenization_common.py | 16 +--------------- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index 86c6f6c2124e..1f95919e27db 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -21,7 +21,7 @@ import pandas as pd from parameterized import parameterized -from transformers import AddedToken, is_flax_available, is_mlx_available, is_tf_available, is_torch_available +from transformers import AddedToken, is_mlx_available, is_torch_available from transformers.models.tapas.tokenization_tapas import ( VOCAB_FILES_NAMES, BasicTokenizer, @@ -1184,18 +1184,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -1203,7 +1191,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT, or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 6f4c6457ec5f..c0d025e1e23d 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -42,9 +42,7 @@ SpecialTokensMixin, Trainer, TrainingArguments, - is_flax_available, is_mlx_available, - is_tf_available, is_torch_available, logging, ) @@ -4703,18 +4701,6 @@ def test_empty_input_string(self): tokenizer_return_type.append("np") output_tensor_type.append(np.int64) - if is_tf_available(): - import tensorflow as tf - - tokenizer_return_type.append("tf") - output_tensor_type.append(tf.int32) - - if is_flax_available(): - import jax.numpy as jnp - - tokenizer_return_type.append("jax") - output_tensor_type.append(jnp.int32) - if is_mlx_available(): import mlx.core as mx @@ -4722,7 +4708,7 @@ def test_empty_input_string(self): output_tensor_type.append(mx.int32) if len(tokenizer_return_type) == 0: - self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found") + self.skipTest(reason="No expected framework from PT, or MLX found") tokenizers = self.get_tokenizers() for tokenizer in tokenizers: From 3bf3a97b829e88c019fa5e230ce62eec948e5575 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 16:28:57 +0200 Subject: [PATCH 24/35] onnx configs --- docs/source/es/_toctree.yml | 2 - .../source/es/converting_tensorflow_models.md | 139 ---------------- docs/source/it/_toctree.yml | 2 - .../source/it/converting_tensorflow_models.md | 144 ----------------- docs/source/pt/_toctree.yml | 2 - .../source/pt/converting_tensorflow_models.md | 152 ------------------ .../modeling_dummy_bert.py | 1 - .../modular-transformers/modeling_roberta.py | 1 - src/transformers/configuration_utils.py | 3 +- .../feature_extraction_sequence_utils.py | 4 +- src/transformers/feature_extraction_utils.py | 2 +- src/transformers/image_processing_base.py | 2 +- .../models/albert/modeling_albert.py | 1 - .../models/align/modeling_align.py | 2 - .../models/altclip/modeling_altclip.py | 1 - .../models/bart/configuration_bart.py | 22 ++- src/transformers/models/beit/modeling_beit.py | 5 - src/transformers/models/bert/modeling_bert.py | 1 - .../modeling_bert_generation.py | 1 - .../models/big_bird/modeling_big_bird.py | 1 - .../configuration_bigbird_pegasus.py | 24 ++- src/transformers/models/bit/modeling_bit.py | 9 +- .../blenderbot/configuration_blenderbot.py | 22 ++- .../configuration_blenderbot_small.py | 24 ++- .../models/blip/modeling_blip_text.py | 2 - .../models/bloom/configuration_bloom.py | 8 +- .../bridgetower/modeling_bridgetower.py | 1 - src/transformers/models/bros/modeling_bros.py | 2 - .../models/camembert/modeling_camembert.py | 1 - .../models/canine/modeling_canine.py | 5 - .../configuration_chinese_clip.py | 11 +- .../chinese_clip/modeling_chinese_clip.py | 2 - src/transformers/models/clap/modeling_clap.py | 1 - .../models/clip/configuration_clip.py | 11 +- .../models/codegen/configuration_codegen.py | 7 +- .../models/convbert/modeling_convbert.py | 1 - .../models/convnext/modeling_convnext.py | 5 - .../models/convnextv2/modeling_convnextv2.py | 5 - src/transformers/models/cvt/modeling_cvt.py | 5 - .../models/data2vec/modeling_data2vec_text.py | 1 - .../data2vec/modeling_data2vec_vision.py | 5 - .../models/deberta/configuration_deberta.py | 7 +- .../deberta_v2/configuration_deberta_v2.py | 7 +- .../modeling_efficientformer.py | 5 - .../deprecated/jukebox/modeling_jukebox.py | 4 +- .../models/deprecated/mctct/modeling_mctct.py | 2 - .../models/deprecated/nat/modeling_nat.py | 5 - .../models/deprecated/nezha/modeling_nezha.py | 2 - .../deprecated/qdqbert/modeling_qdqbert.py | 2 - .../models/deprecated/realm/modeling_realm.py | 2 - .../models/deprecated/van/modeling_van.py | 5 - .../models/dinat/modeling_dinat.py | 5 - .../models/dinov2/modeling_dinov2.py | 5 - .../modeling_dinov2_with_registers.py | 5 - .../modeling_dinov3_convnext.py | 5 - .../models/dinov3_vit/modeling_dinov3_vit.py | 5 - .../models/donut/modeling_donut_swin.py | 5 - .../models/electra/modeling_electra.py | 2 - src/transformers/models/eomt/modeling_eomt.py | 5 - .../models/ernie/modeling_ernie.py | 2 - .../models/flava/modeling_flava.py | 2 - .../models/florence2/modeling_florence2.py | 5 - src/transformers/models/fnet/modeling_fnet.py | 2 - .../models/focalnet/modeling_focalnet.py | 5 - src/transformers/models/git/modeling_git.py | 2 - src/transformers/models/glpn/modeling_glpn.py | 5 - .../models/gpt2/configuration_gpt2.py | 5 +- .../models/gpt_neo/configuration_gpt_neo.py | 10 +- .../models/gptj/configuration_gptj.py | 5 +- .../grounding_dino/modeling_grounding_dino.py | 5 - .../models/groupvit/configuration_groupvit.py | 11 +- .../models/hiera/modeling_hiera.py | 5 - .../models/ibert/modeling_ibert.py | 2 - .../models/imagegpt/configuration_imagegpt.py | 11 +- .../models/kosmos2_5/modeling_kosmos2_5.py | 1 - .../models/layoutlm/configuration_layoutlm.py | 14 +- .../layoutlmv3/configuration_layoutlmv3.py | 10 +- src/transformers/models/lilt/modeling_lilt.py | 2 - .../longformer/configuration_longformer.py | 8 +- .../models/longformer/modeling_longformer.py | 2 - src/transformers/models/luke/modeling_luke.py | 2 - .../models/luke/tokenization_luke.py | 2 +- .../models/lxmert/modeling_lxmert.py | 1 - .../models/m2m_100/configuration_m2m_100.py | 12 +- .../models/marian/configuration_marian.py | 35 ++-- .../maskformer/modeling_maskformer_swin.py | 5 - .../models/mbart/configuration_mbart.py | 22 ++- .../megatron_bert/modeling_megatron_bert.py | 3 - .../models/mgp_str/modeling_mgp_str.py | 5 - .../models/mluke/tokenization_mluke.py | 2 +- .../modeling_mm_grounding_dino.py | 5 - src/transformers/models/mra/modeling_mra.py | 2 - .../nystromformer/modeling_nystromformer.py | 2 - .../models/owlvit/configuration_owlvit.py | 11 +- .../perceiver/configuration_perceiver.py | 9 +- .../models/poolformer/modeling_poolformer.py | 5 - src/transformers/models/pvt/modeling_pvt.py | 5 - .../models/pvt_v2/modeling_pvt_v2.py | 5 - .../models/rembert/modeling_rembert.py | 2 - .../models/roberta/modeling_roberta.py | 1 - .../modeling_roberta_prelayernorm.py | 1 - .../models/roc_bert/modeling_roc_bert.py | 2 - .../models/roformer/modeling_roformer.py | 1 - .../models/segformer/modeling_segformer.py | 5 - .../models/seggpt/modeling_seggpt.py | 5 - .../models/splinter/modeling_splinter.py | 2 - .../squeezebert/modeling_squeezebert.py | 2 - .../swiftformer/modeling_swiftformer.py | 5 - src/transformers/models/swin/modeling_swin.py | 5 - .../models/swin2sr/modeling_swin2sr.py | 5 - .../models/swinv2/modeling_swinv2.py | 5 - .../models/tapas/modeling_tapas.py | 2 - .../timesformer/modeling_timesformer.py | 5 - src/transformers/models/vilt/modeling_vilt.py | 2 - .../configuration_vision_encoder_decoder.py | 10 +- .../visual_bert/modeling_visual_bert.py | 3 - .../models/vitdet/modeling_vitdet.py | 5 - .../models/vjepa2/modeling_vjepa2.py | 5 - .../models/whisper/configuration_whisper.py | 10 +- .../models/x_clip/modeling_x_clip.py | 5 - .../xlm_roberta/modeling_xlm_roberta.py | 1 - .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 - src/transformers/models/xmod/modeling_xmod.py | 1 - src/transformers/models/yoso/modeling_yoso.py | 2 - src/transformers/onnx/config.py | 34 ++-- src/transformers/onnx/convert.py | 4 +- tests/utils/test_auto_docstring.py | 4 +- 127 files changed, 188 insertions(+), 913 deletions(-) delete mode 100644 docs/source/es/converting_tensorflow_models.md delete mode 100644 docs/source/it/converting_tensorflow_models.md delete mode 100644 docs/source/pt/converting_tensorflow_models.md diff --git a/docs/source/es/_toctree.yml b/docs/source/es/_toctree.yml index 85a9aec02e7d..d016c8ca88ec 100644 --- a/docs/source/es/_toctree.yml +++ b/docs/source/es/_toctree.yml @@ -64,8 +64,6 @@ title: Entrenador - local: sagemaker title: Ejecutar el entrenamiento en Amazon SageMaker - - local: converting_tensorflow_models - title: Convertir checkpoints de TensorFlow - local: serialization title: Exportar a ONNX - local: torchscript diff --git a/docs/source/es/converting_tensorflow_models.md b/docs/source/es/converting_tensorflow_models.md deleted file mode 100644 index 290f325b96c7..000000000000 --- a/docs/source/es/converting_tensorflow_models.md +++ /dev/null @@ -1,139 +0,0 @@ - - -# Convertir checkpoints de Tensorflow - -Te proporcionamos una interfaz de línea de comando (`CLI`, por sus siglas en inglés) para convertir puntos de control (_checkpoints_) originales de Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM en modelos que se puedan cargar utilizando los métodos `from_pretrained` de la biblioteca. - - - -Desde 2.3.0, el script para convertir es parte de la CLI de transformers (**transformers**) disponible en cualquier instalación de transformers >= 2.3.0. - -La siguiente documentación refleja el formato para el comando **transformers convert**. - - - -## BERT - -Puedes convertir cualquier checkpoint de TensorFlow para BERT (en particular, [los modelos pre-entrenados y publicados por Google](https://github.com/google-research/bert#pre-trained-models)) en un archivo de PyTorch mediante el script [convert_bert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py). - -Esta CLI toma como entrada un checkpoint de TensorFlow (tres archivos que comienzan con `bert_model.ckpt`) y el archivo de configuración asociado (`bert_config.json`), y crea un modelo PyTorch para esta configuración, carga los pesos del checkpoint de TensorFlow en el modelo de PyTorch y guarda el modelo resultante en un archivo estándar de PyTorch que se puede importar usando `from_pretrained()` (ve el ejemplo en [Tour rápido](quicktour), [run_glue.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification/run_glue.py)). - -Solo necesitas ejecutar este script **una vez** para convertir un modelo a PyTorch. Después, puedes ignorar el checkpoint de TensorFlow (los tres archivos que comienzan con `bert_model.ckpt`), pero asegúrate de conservar el archivo de configuración (`bert_config.json`) y el archivo de vocabulario (`vocab.txt`) ya que estos también son necesarios para el modelo en PyTorch. - -Para ejecutar este script deberás tener instalado TensorFlow y PyTorch (`pip install tensorflow`). El resto del repositorio solo requiere PyTorch. - -Aquí hay un ejemplo del proceso para convertir un modelo `BERT-Base Uncased` pre-entrenado: - -```bash -export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 - -transformers convert --model_type bert \ - --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \ - --config $BERT_BASE_DIR/bert_config.json \ - --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin -``` - -Puedes descargar los modelos pre-entrenados de Google para la conversión [aquí](https://github.com/google-research/bert#pre-trained-models). - -## ALBERT - -Convierte los checkpoints del modelo ALBERT de TensorFlow a PyTorch usando el script [convert_albert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py). - -La CLI toma como entrada un checkpoint de TensorFlow (tres archivos que comienzan con `model.ckpt-best`) y el archivo de configuración adjunto (`albert_config.json`), luego crea y guarda un modelo de PyTorch. Para ejecutar esta conversión deberás tener instalados TensorFlow y PyTorch. - -Aquí hay un ejemplo del proceso para convertir un modelo `ALBERT Base` pre-entrenado: - -```bash -export ALBERT_BASE_DIR=/path/to/albert/albert_base - -transformers convert --model_type albert \ - --tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \ - --config $ALBERT_BASE_DIR/albert_config.json \ - --pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin -``` - -Puedes descargar los modelos pre-entrenados de Google para la conversión [aquí](https://github.com/google-research/albert#pre-trained-models). - -## OpenAI GPT - -Este es un ejemplo del proceso para convertir un modelo OpenAI GPT pre-entrenado, asumiendo que tu checkpoint de NumPy se guarda con el mismo formato que el modelo pre-entrenado de OpenAI (más información [aquí](https://github.com/openai/finetune-transformer-lm)): - -```bash -export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights - -transformers convert --model_type gpt \ - --tf_checkpoint $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT_CONFIG] \ - [--finetuning_task_name OPENAI_GPT_FINETUNED_TASK] \ -``` - -## OpenAI GPT-2 - -Aquí hay un ejemplo del proceso para convertir un modelo OpenAI GPT-2 pre-entrenado (más información [aquí](https://github.com/openai/gpt-2)): - -```bash -export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/openai-community/gpt2/pretrained/weights - -transformers convert --model_type gpt2 \ - --tf_checkpoint $OPENAI_GPT2_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT2_CONFIG] \ - [--finetuning_task_name OPENAI_GPT2_FINETUNED_TASK] -``` - -## XLNet - -Aquí hay un ejemplo del proceso para convertir un modelo XLNet pre-entrenado: - -```bash -export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint -export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config - -transformers convert --model_type xlnet \ - --tf_checkpoint $TRANSFO_XL_CHECKPOINT_PATH \ - --config $TRANSFO_XL_CONFIG_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--finetuning_task_name XLNET_FINETUNED_TASK] \ -``` - -## XLM - -Aquí hay un ejemplo del proceso para convertir un modelo XLM pre-entrenado: - -```bash -export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint - -transformers convert --model_type xlm \ - --tf_checkpoint $XLM_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT - [--config XML_CONFIG] \ - [--finetuning_task_name XML_FINETUNED_TASK] -``` - -## T5 - -Aquí hay un ejemplo del proceso para convertir un modelo T5 pre-entrenado: - -```bash -export T5=/path/to/t5/uncased_L-12_H-768_A-12 - -transformers convert --model_type t5 \ - --tf_checkpoint $T5/t5_model.ckpt \ - --config $T5/t5_config.json \ - --pytorch_dump_output $T5/pytorch_model.bin -``` diff --git a/docs/source/it/_toctree.yml b/docs/source/it/_toctree.yml index 47d90f9a9a85..2ba1b8ecede3 100644 --- a/docs/source/it/_toctree.yml +++ b/docs/source/it/_toctree.yml @@ -29,8 +29,6 @@ title: Addestramento con script - local: multilingual title: Modelli multilingua per l'inferenza - - local: converting_tensorflow_models - title: Convertire modelli tensorflow - local: serialization title: Esporta modelli Transformers - local: perf_train_cpu diff --git a/docs/source/it/converting_tensorflow_models.md b/docs/source/it/converting_tensorflow_models.md deleted file mode 100644 index dace244fa6dd..000000000000 --- a/docs/source/it/converting_tensorflow_models.md +++ /dev/null @@ -1,144 +0,0 @@ - - -# Convertire checkpoint di Tensorflow - -È disponibile un'interfaccia a linea di comando per convertire gli originali checkpoint di Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM -in modelli che possono essere caricati utilizzando i metodi `from_pretrained` della libreria. - - - -A partire dalla versione 2.3.0 lo script di conversione è parte di transformers CLI (**transformers**), disponibile in ogni installazione -di transformers >=2.3.0. - -La seguente documentazione riflette il formato dei comandi di **transformers convert**. - - - -## BERT - -Puoi convertire qualunque checkpoint Tensorflow di BERT (in particolare -[i modeli pre-allenati rilasciati da Google](https://github.com/google-research/bert#pre-trained-models)) -in un file di salvataggio Pytorch utilizzando lo script -[convert_bert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py). - -Questo CLI prende come input un checkpoint di Tensorflow (tre files che iniziano con `bert_model.ckpt`) ed il relativo -file di configurazione (`bert_config.json`), crea un modello Pytorch per questa configurazione, carica i pesi dal -checkpoint di Tensorflow nel modello di Pytorch e salva il modello che ne risulta in un file di salvataggio standard di Pytorch che -può essere importato utilizzando `from_pretrained()` (vedi l'esempio nel -[quicktour](quicktour) , [run_glue.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification/run_glue.py) ). - -Devi soltanto lanciare questo script di conversione **una volta** per ottenere un modello Pytorch. Dopodichè, potrai tralasciare -il checkpoint di Tensorflow (i tre files che iniziano con `bert_model.ckpt`), ma assicurati di tenere il file di configurazione -(`bert_config.json`) ed il file di vocabolario (`vocab.txt`) in quanto queste componenti sono necessarie anche per il modello di Pytorch. - -Per lanciare questo specifico script di conversione avrai bisogno di un'installazione di Tensorflow e di Pytorch -(`pip install tensorflow`). Il resto della repository richiede soltanto Pytorch. - -Questo è un esempio del processo di conversione per un modello `BERT-Base Uncased` pre-allenato: - -```bash -export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 -transformers convert --model_type bert \ - --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \ - --config $BERT_BASE_DIR/bert_config.json \ - --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin -``` - -Puoi scaricare i modelli pre-allenati di Google per la conversione [qua](https://github.com/google-research/bert#pre-trained-models). - -## ALBERT - -Per il modello ALBERT, converti checkpoint di Tensoflow in Pytorch utilizzando lo script -[convert_albert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py). - -Il CLI prende come input un checkpoint di Tensorflow (tre files che iniziano con `model.ckpt-best`) e i relativi file di -configurazione (`albert_config.json`), dopodichè crea e salva un modello Pytorch. Per lanciare questa conversione -avrai bisogno di un'installazione di Tensorflow e di Pytorch. - -Ecco un esempio del procedimento di conversione di un modello `ALBERT Base` pre-allenato: - -```bash -export ALBERT_BASE_DIR=/path/to/albert/albert_base -transformers convert --model_type albert \ - --tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \ - --config $ALBERT_BASE_DIR/albert_config.json \ - --pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin -``` - -Puoi scaricare i modelli pre-allenati di Google per la conversione [qui](https://github.com/google-research/albert#pre-trained-models). - -## OpenAI GPT - -Ecco un esempio del processo di conversione di un modello OpenAI GPT pre-allenato, assumendo che il tuo checkpoint di NumPy -sia salvato nello stesso formato dei modelli pre-allenati OpenAI (vedi [qui](https://github.com/openai/finetune-transformer-lm)): -```bash -export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights -transformers convert --model_type gpt \ - --tf_checkpoint $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT_CONFIG] \ - [--finetuning_task_name OPENAI_GPT_FINETUNED_TASK] \ -``` - -## OpenAI GPT-2 - -Ecco un esempio del processo di conversione di un modello OpenAI GPT-2 pre-allenato (vedi [qui](https://github.com/openai/gpt-2)): - -```bash -export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/openai-community/gpt2/pretrained/weights -transformers convert --model_type gpt2 \ - --tf_checkpoint $OPENAI_GPT2_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT2_CONFIG] \ - [--finetuning_task_name OPENAI_GPT2_FINETUNED_TASK] -``` - -## XLNet - -Ecco un esempio del processo di conversione di un modello XLNet pre-allenato: - -```bash -export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint -export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config -transformers convert --model_type xlnet \ - --tf_checkpoint $TRANSFO_XL_CHECKPOINT_PATH \ - --config $TRANSFO_XL_CONFIG_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--finetuning_task_name XLNET_FINETUNED_TASK] \ -``` - -## XLM - -Ecco un esempio del processo di conversione di un modello XLM pre-allenato: - -```bash -export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint -transformers convert --model_type xlm \ - --tf_checkpoint $XLM_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT - [--config XML_CONFIG] \ - [--finetuning_task_name XML_FINETUNED_TASK] -``` - -## T5 - -Ecco un esempio del processo di conversione di un modello T5 pre-allenato: - -```bash -export T5=/path/to/t5/uncased_L-12_H-768_A-12 -transformers convert --model_type t5 \ - --tf_checkpoint $T5/t5_model.ckpt \ - --config $T5/t5_config.json \ - --pytorch_dump_output $T5/pytorch_model.bin -``` diff --git a/docs/source/pt/_toctree.yml b/docs/source/pt/_toctree.yml index d042168f7b9b..c525a2a4faa1 100644 --- a/docs/source/pt/_toctree.yml +++ b/docs/source/pt/_toctree.yml @@ -23,8 +23,6 @@ title: Compartilhando modelos customizados - local: run_scripts title: Treinamento a partir de um script - - local: converting_tensorflow_models - title: Convertendo checkpoints do TensorFlow para Pytorch - local: serialization title: Exportando modelos para ONNX - sections: diff --git a/docs/source/pt/converting_tensorflow_models.md b/docs/source/pt/converting_tensorflow_models.md deleted file mode 100644 index 446acd62ea8f..000000000000 --- a/docs/source/pt/converting_tensorflow_models.md +++ /dev/null @@ -1,152 +0,0 @@ - - -# Convertendo checkpoints do TensorFlow para Pytorch - -Uma interface de linha de comando é fornecida para converter os checkpoints originais Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM em modelos -que podem ser carregados usando os métodos `from_pretrained` da biblioteca. - - - -A partir da versão 2.3.0 o script de conversão agora faz parte do transformers CLI (**transformers**) disponível em qualquer instalação -transformers >= 2.3.0. - -A documentação abaixo reflete o formato do comando **transformers convert**. - - - -## BERT - -Você pode converter qualquer checkpoint do BERT em TensorFlow (em particular [os modelos pré-treinados lançados pelo Google](https://github.com/google-research/bert#pre-trained-models)) em um arquivo PyTorch usando um -[convert_bert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py) script. - -Esta Interface de Linha de Comando (CLI) recebe como entrada um checkpoint do TensorFlow (três arquivos começando com `bert_model.ckpt`) e o -arquivo de configuração (`bert_config.json`), e então cria um modelo PyTorch para esta configuração, carrega os pesos -do checkpoint do TensorFlow no modelo PyTorch e salva o modelo resultante em um arquivo PyTorch que pode -ser importado usando `from_pretrained()` (veja o exemplo em [quicktour](quicktour) , [run_glue.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification/run_glue.py) ). - -Você só precisa executar este script de conversão **uma vez** para obter um modelo PyTorch. Você pode então desconsiderar o checkpoint em - TensorFlow (os três arquivos começando com `bert_model.ckpt`), mas certifique-se de manter o arquivo de configuração (\ -`bert_config.json`) e o arquivo de vocabulário (`vocab.txt`), pois eles também são necessários para o modelo PyTorch. - -Para executar este script de conversão específico, você precisará ter o TensorFlow e o PyTorch instalados (`pip install tensorflow`). O resto do repositório requer apenas o PyTorch. - -Aqui está um exemplo do processo de conversão para um modelo `BERT-Base Uncased` pré-treinado: - -```bash -export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 - -transformers convert --model_type bert \ - --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \ - --config $BERT_BASE_DIR/bert_config.json \ - --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin -``` - -Você pode baixar os modelos pré-treinados do Google para a conversão [aqui](https://github.com/google-research/bert#pre-trained-models). - -## ALBERT - -Converta os checkpoints do modelo ALBERT em TensorFlow para PyTorch usando o -[convert_albert_original_tf_checkpoint_to_pytorch.py](https://github.com/huggingface/transformers/tree/main/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py) script. - -A Interface de Linha de Comando (CLI) recebe como entrada um checkpoint do TensorFlow (três arquivos começando com `model.ckpt-best`) e o -arquivo de configuração (`albert_config.json`), então cria e salva um modelo PyTorch. Para executar esta conversão, você -precisa ter o TensorFlow e o PyTorch instalados. - -Aqui está um exemplo do processo de conversão para o modelo `ALBERT Base` pré-treinado: - -```bash -export ALBERT_BASE_DIR=/path/to/albert/albert_base - -transformers convert --model_type albert \ - --tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \ - --config $ALBERT_BASE_DIR/albert_config.json \ - --pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin -``` - -Você pode baixar os modelos pré-treinados do Google para a conversão [aqui](https://github.com/google-research/albert#pre-trained-models). - -## OpenAI GPT - -Aqui está um exemplo do processo de conversão para um modelo OpenAI GPT pré-treinado, supondo que seu checkpoint NumPy -foi salvo com o mesmo formato do modelo pré-treinado OpenAI (veja [aqui](https://github.com/openai/finetune-transformer-lm)\ -) - -```bash -export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights - -transformers convert --model_type gpt \ - --tf_checkpoint $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT_CONFIG] \ - [--finetuning_task_name OPENAI_GPT_FINETUNED_TASK] \ -``` - -## OpenAI GPT-2 - -Aqui está um exemplo do processo de conversão para um modelo OpenAI GPT-2 pré-treinado (consulte [aqui](https://github.com/openai/gpt-2)) - -```bash -export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/openai-community/gpt2/pretrained/weights - -transformers convert --model_type gpt2 \ - --tf_checkpoint $OPENAI_GPT2_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--config OPENAI_GPT2_CONFIG] \ - [--finetuning_task_name OPENAI_GPT2_FINETUNED_TASK] -``` - -## XLNet - -Aqui está um exemplo do processo de conversão para um modelo XLNet pré-treinado: - -```bash -export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint -export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config - -transformers convert --model_type xlnet \ - --tf_checkpoint $TRANSFO_XL_CHECKPOINT_PATH \ - --config $TRANSFO_XL_CONFIG_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \ - [--finetuning_task_name XLNET_FINETUNED_TASK] \ -``` - -## XLM - -Aqui está um exemplo do processo de conversão para um modelo XLM pré-treinado: - -```bash -export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint - -transformers convert --model_type xlm \ - --tf_checkpoint $XLM_CHECKPOINT_PATH \ - --pytorch_dump_output $PYTORCH_DUMP_OUTPUT - [--config XML_CONFIG] \ - [--finetuning_task_name XML_FINETUNED_TASK] -``` - -## T5 - -Aqui está um exemplo do processo de conversão para um modelo T5 pré-treinado: - -```bash -export T5=/path/to/t5/uncased_L-12_H-768_A-12 - -transformers convert --model_type t5 \ - --tf_checkpoint $T5/t5_model.ckpt \ - --config $T5/t5_config.json \ - --pytorch_dump_output $T5/pytorch_model.bin -``` diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index bf4b7ce94ee5..9df092f73e6e 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -34,7 +34,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index f9db8677c0d5..2ae39a555892 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -36,7 +36,6 @@ def __init__(self, config): ) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 933d21b2b436..6d06d72353e8 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -139,8 +139,7 @@ class PretrainedConfig(PushToHubMixin): architectures (`list[str]`, *optional*): Model architectures that can be used with the model pretrained weights. finetuning_task (`str`, *optional*): - Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow - or PyTorch) checkpoint. + Name of the task used to fine-tune the model. id2label (`dict[int, str]`, *optional*): A map from index (for instance prediction index, or target index) to label. label2id (`dict[str, int]`, *optional*): diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index f4a0fe30441d..1a48062cb5c1 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -74,7 +74,7 @@ def pad( - If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + If the `processed_features` passed are dictionary of numpy arrays or PyTorch tensors the result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the specific device of your tensors however. @@ -87,7 +87,7 @@ def pad( list[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader collate function. - Instead of `list[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), + Instead of `list[float]` you can have tensors (numpy arrays or PyTorch tensors), see the note above for the return type. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 20d23d78db1c..fd9eb56941b9 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -68,7 +68,7 @@ class BatchFeature(UserDict): Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', etc.). tensor_type (`Union[None, str, TensorType]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at initialization. """ diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index dfe94ffd0df7..8bd65e9bc3ce 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -55,7 +55,7 @@ class BatchFeature(BaseBatchFeature): data (`dict`): Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.). tensor_type (`Union[None, str, TensorType]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at initialization. """ diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 0b4f9c70d914..c3d1dc540223 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -57,7 +57,6 @@ def __init__(self, config: AlbertConfig): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index c226a3b36ac6..839856b92119 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -516,8 +516,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 9abac9d7e9b2..5cae61f12d7f 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -101,7 +101,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 90781feab3b5..e560bfa7d4a2 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -17,13 +17,13 @@ import warnings from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -244,16 +244,15 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -312,10 +311,9 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) if self.use_past: @@ -350,7 +348,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -367,7 +364,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs def generate_dummy_inputs( @@ -376,20 +373,19 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) elif self.task == "causal-lm": common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) else: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) return common_inputs diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index cb4e0d712651..18f4fad10db9 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -63,11 +63,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 3502e0094a3e..186d13bb7541 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -59,7 +59,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index c42093237134..c1b09041cd74 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -444,7 +444,6 @@ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 2f776a6205f5..69dc11a7cb69 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -75,7 +75,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py index 29b481c78ad1..dc32c34e0d25 100644 --- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -16,13 +16,13 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -186,7 +186,7 @@ def __init__( ) -# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig +# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->BigBirdPegasus class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: @@ -251,16 +251,15 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -319,10 +318,9 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) if self.use_past: @@ -357,7 +355,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -374,7 +371,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs def generate_dummy_inputs( @@ -383,20 +380,19 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) elif self.task == "causal-lm": common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) else: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) return common_inputs diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 1e491f06eae6..616c6d31d339 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -80,7 +80,7 @@ def get_padding_value(padding=None, kernel_size=7, stride=1, dilation=1) -> tupl class WeightStandardizedConv2d(nn.Conv2d): - """Conv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model. + """Conv2d with Weight Standardization. Used for ViT Hybrid model. Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://huggingface.co/papers/1903.10520v2) @@ -197,8 +197,6 @@ def forward(self, input): class BitMaxPool2d(nn.MaxPool2d): - """Tensorflow like 'SAME' wrapper for 2D max pooling""" - def __init__( self, kernel_size: int, @@ -280,11 +278,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py index 44287991375a..8e4e4812aafd 100644 --- a/src/transformers/models/blenderbot/configuration_blenderbot.py +++ b/src/transformers/models/blenderbot/configuration_blenderbot.py @@ -16,11 +16,11 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...file_utils import TensorType, is_torch_available +from ...file_utils import is_torch_available from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension from ...utils import logging @@ -228,15 +228,14 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -285,10 +284,9 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) if self.use_past: @@ -322,7 +320,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -339,7 +336,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.generate_dummy_inputs @@ -349,20 +346,19 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) elif self.task == "causal-lm": common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) else: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) return common_inputs diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py index 6d43b975e5ba..6cd7f7275c17 100644 --- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py @@ -16,11 +16,11 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...file_utils import TensorType, is_torch_available +from ...file_utils import is_torch_available from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension from ...utils import logging @@ -164,7 +164,7 @@ def __init__( ) -# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig +# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->BlenderbotSmall class BlenderbotSmallOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: @@ -229,16 +229,15 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -297,10 +296,9 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) if self.use_past: @@ -335,7 +333,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -352,7 +349,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs def generate_dummy_inputs( @@ -361,20 +358,19 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) elif self.task == "causal-lm": common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) else: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) return common_inputs diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 6f1f58c75334..427adaf659db 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -49,8 +49,6 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index 74748c113041..8d5fa7656a73 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: - from ... import PreTrainedTokenizer, TensorType + from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfigWithPast, PatchingSpec @@ -187,10 +187,12 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) # We need to order the input in the way they appears in the forward() diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 3f91eb91ae4d..97fcc469a4c6 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -838,7 +838,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index d01a4c5a1c6d..7bb9e0ac762a 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -126,8 +126,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index eb5439d93a3d..aa86eb18d652 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -60,7 +60,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 9acbb476c2f8..545919dc7b77 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -99,7 +99,6 @@ def __init__(self, config): self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -197,8 +196,6 @@ def __init__(self, config): ) self.activation = ACT2FN[config.hidden_act] - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, char_encoding: torch.Tensor) -> torch.Tensor: @@ -245,8 +242,6 @@ def __init__(self, config): stride=1, ) self.activation = ACT2FN[config.hidden_act] - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/chinese_clip/configuration_chinese_clip.py b/src/transformers/models/chinese_clip/configuration_chinese_clip.py index e7c98d0d2d9f..5b9c31965585 100644 --- a/src/transformers/models/chinese_clip/configuration_chinese_clip.py +++ b/src/transformers/models/chinese_clip/configuration_chinese_clip.py @@ -16,12 +16,11 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...processing_utils import ProcessorMixin - from ...utils import TensorType from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -405,13 +404,15 @@ def generate_dummy_inputs( processor: "ProcessorMixin", batch_size: int = -1, seq_length: int = -1, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: text_input_dict = super().generate_dummy_inputs( - processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + processor.tokenizer, + batch_size=batch_size, + seq_length=seq_length, ) image_input_dict = super().generate_dummy_inputs( - processor.image_processor, batch_size=batch_size, framework=framework + processor.image_processor, + batch_size=batch_size, ) return {**text_input_dict, **image_input_dict} diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index a0b461ab3ed3..a689886abc37 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -98,8 +98,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 6c8633788b63..b8983eecf035 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1008,7 +1008,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py index 0b4fe6ba37f6..22c245485a0d 100644 --- a/src/transformers/models/clip/configuration_clip.py +++ b/src/transformers/models/clip/configuration_clip.py @@ -16,12 +16,11 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...processing_utils import ProcessorMixin - from ...utils import TensorType from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -393,13 +392,15 @@ def generate_dummy_inputs( processor: "ProcessorMixin", batch_size: int = -1, seq_length: int = -1, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: text_input_dict = super().generate_dummy_inputs( - processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + processor.tokenizer, + batch_size=batch_size, + seq_length=seq_length, ) image_input_dict = super().generate_dummy_inputs( - processor.image_processor, batch_size=batch_size, framework=framework + processor.image_processor, + batch_size=batch_size, ) return {**text_input_dict, **image_input_dict} diff --git a/src/transformers/models/codegen/configuration_codegen.py b/src/transformers/models/codegen/configuration_codegen.py index 6a9ab842710c..658f3cfca1ac 100644 --- a/src/transformers/models/codegen/configuration_codegen.py +++ b/src/transformers/models/codegen/configuration_codegen.py @@ -18,7 +18,7 @@ from collections.abc import Mapping from typing import Any, Optional -from ... import PreTrainedTokenizer, TensorType, is_torch_available +from ... import PreTrainedTokenizer, is_torch_available from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfigWithPast, PatchingSpec from ...utils import logging @@ -146,7 +146,7 @@ def __init__( ) -# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig +# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig with GPT2->CodeGen class CodeGenOnnxConfig(OnnxConfigWithPast): def __init__( self, @@ -185,10 +185,9 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) # We need to order the input in the way they appears in the forward() diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a5bc7912540e..5f4dd419b4fc 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -52,7 +52,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index 3120c140d2ed..d859c89ecb97 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -41,11 +41,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index bfa5338f5e86..8fb27efbe502 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -41,11 +41,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 9d935ee84893..bd27b2db7f24 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -54,11 +54,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 06da51b8f1de..0fbe9e4802ad 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -61,7 +61,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f214f8eb6a0b..c5229c3a7540 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -62,11 +62,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/deberta/configuration_deberta.py b/src/transformers/models/deberta/configuration_deberta.py index 3e23a73a8c38..49015eb7cc5b 100644 --- a/src/transformers/models/deberta/configuration_deberta.py +++ b/src/transformers/models/deberta/configuration_deberta.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -24,7 +24,7 @@ if TYPE_CHECKING: - from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType + from ... import FeatureExtractionMixin, PreTrainedTokenizerBase logger = logging.get_logger(__name__) @@ -185,13 +185,12 @@ def generate_dummy_inputs( seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, tokenizer: "PreTrainedTokenizerBase" = None, ) -> Mapping[str, Any]: - dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework) + dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor) if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs: del dummy_inputs["token_type_ids"] return dummy_inputs diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py index 5189cfd53ae7..43576e815d07 100644 --- a/src/transformers/models/deberta_v2/configuration_deberta_v2.py +++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -24,7 +24,7 @@ if TYPE_CHECKING: - from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType + from ... import FeatureExtractionMixin, PreTrainedTokenizerBase logger = logging.get_logger(__name__) @@ -184,13 +184,12 @@ def generate_dummy_inputs( seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, tokenizer: "PreTrainedTokenizerBase" = None, ) -> Mapping[str, Any]: - dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework) + dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor) if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs: del dummy_inputs["token_type_ids"] return dummy_inputs diff --git a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py index d35d3e82c007..2167df912d87 100644 --- a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py +++ b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py @@ -241,11 +241,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py index f928d49cf5f7..daaa4b2ee489 100755 --- a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py @@ -907,7 +907,7 @@ def _attn(self, query_states, key_states, value_states, sample): def merge_heads(self, hidden_states): hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1)) - return hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct merge_states + return hidden_states.view(*new_hidden_states_shape) def split_heads(self, hidden_states, is_key=False): new_hidden_states_shape = ( @@ -915,7 +915,7 @@ def split_heads(self, hidden_states, is_key=False): self.n_heads, hidden_states.size(-1) // self.n_heads, ) - hidden_states = hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct split_states + hidden_states = hidden_states.view(*new_hidden_states_shape) if is_key: return hidden_states.permute(0, 2, 3, 1) else: diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 253b09c1c43c..f4982935336d 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -118,8 +118,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = MCTCTLayerNorm() self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/deprecated/nat/modeling_nat.py b/src/transformers/models/deprecated/nat/modeling_nat.py index a619cdb11225..935ae89ef966 100644 --- a/src/transformers/models/deprecated/nat/modeling_nat.py +++ b/src/transformers/models/deprecated/nat/modeling_nat.py @@ -258,11 +258,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 5fcc0318a50a..eaf1cedfed32 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -96,8 +96,6 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.register_buffer( diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index 57ffcae80e56..f522a1d72154 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -79,8 +79,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index bc177fcde7be..69bab60f6803 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -54,8 +54,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index c0fc0bc1a637..6ee0e881e558 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -50,11 +50,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 4b7ec37b0ea8..80349f29592d 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -214,11 +214,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 0a9a2cba1da7..f84d442a3efc 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -297,11 +297,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index a02ac4c58476..042c21babd19 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -313,11 +313,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py index df2ef491192c..da68b17c8587 100644 --- a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +++ b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py @@ -38,11 +38,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index dbea73e6caf5..76e365903082 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -329,11 +329,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index c5736b16183b..21f400cb2b68 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -334,11 +334,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index c944fb45a38d..d3b47ea55b79 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -55,8 +55,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index 3e979040388d..e7e1624c1406 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -808,11 +808,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 7cbce6b2d20b..f244efb3f01c 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -60,8 +60,6 @@ def __init__(self, config): if config.use_task_id: self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 266c3e96af5a..8705515f4270 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -375,8 +375,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index 763756faf73f..64947dea1285 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -53,11 +53,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 2ad09a3b268b..3ed84887ce6f 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -91,8 +91,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # NOTE: This is the project layer and will be needed. The original code allows for different embedding and different model dimensions. self.projection = nn.Linear(config.hidden_size, config.hidden_size) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index ed31b5deb527..9501db673037 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -247,11 +247,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 4122b7a0df79..6b99722aa6a8 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -76,8 +76,6 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index e326750743a1..dc82249f2b7a 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -36,11 +36,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index db5151a2ba15..0fa19ada1c90 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -19,7 +19,7 @@ from collections.abc import Mapping from typing import Any, Optional -from ... import PreTrainedTokenizer, TensorType, is_torch_available +from ... import PreTrainedTokenizer, is_torch_available from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfigWithPast, PatchingSpec from ...utils import logging @@ -228,10 +228,9 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) # We need to order the input in the way they appears in the forward() diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index 875a170277d2..a9bbfcd33ef8 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -16,9 +16,9 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any -from ... import PreTrainedTokenizer, TensorType, is_torch_available +from ... import PreTrainedTokenizer, is_torch_available from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -227,10 +227,12 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) # We need to order the input in the way they appears in the forward() diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 68c690996880..278bfbf0be96 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -18,7 +18,7 @@ from collections.abc import Mapping from typing import Any, Optional -from ... import PreTrainedTokenizer, TensorType, is_torch_available +from ... import PreTrainedTokenizer, is_torch_available from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfigWithPast, PatchingSpec from ...utils import logging @@ -174,10 +174,9 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) # We need to order the input in the way they appears in the forward() diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 5d674caca6fa..727749548042 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -861,11 +861,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/groupvit/configuration_groupvit.py b/src/transformers/models/groupvit/configuration_groupvit.py index d17288ede723..cd9fb2d0469e 100644 --- a/src/transformers/models/groupvit/configuration_groupvit.py +++ b/src/transformers/models/groupvit/configuration_groupvit.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -25,7 +25,6 @@ if TYPE_CHECKING: from ...processing_utils import ProcessorMixin - from ...utils import TensorType logger = logging.get_logger(__name__) @@ -389,13 +388,15 @@ def generate_dummy_inputs( processor: "ProcessorMixin", batch_size: int = -1, seq_length: int = -1, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: text_input_dict = super().generate_dummy_inputs( - processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + processor.tokenizer, + batch_size=batch_size, + seq_length=seq_length, ) image_input_dict = super().generate_dummy_inputs( - processor.image_processor, batch_size=batch_size, framework=framework + processor.image_processor, + batch_size=batch_size, ) return {**text_input_dict, **image_input_dict} diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index 0c084f0f836e..7ae70f6cbe8b 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -404,11 +404,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 57b3df2f570b..a9ab176d0bab 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -89,8 +89,6 @@ def __init__(self, config): self.embeddings_act1 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode) self.embeddings_act2 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = IntLayerNorm( config.hidden_size, eps=config.layer_norm_eps, diff --git a/src/transformers/models/imagegpt/configuration_imagegpt.py b/src/transformers/models/imagegpt/configuration_imagegpt.py index 8cfa8d5e4782..435324721d86 100644 --- a/src/transformers/models/imagegpt/configuration_imagegpt.py +++ b/src/transformers/models/imagegpt/configuration_imagegpt.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -24,7 +24,7 @@ if TYPE_CHECKING: - from ... import FeatureExtractionMixin, TensorType + from ... import FeatureExtractionMixin logger = logging.get_logger(__name__) @@ -159,13 +159,12 @@ def generate_dummy_inputs( batch_size: int = 1, seq_length: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, num_channels: int = 3, image_width: int = 32, image_height: int = 32, ) -> Mapping[str, Any]: """ - Generate inputs to provide to the ONNX exporter for the specific framework + Generate inputs to provide to the ONNX exporter. Args: preprocessor ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]): @@ -178,8 +177,6 @@ def generate_dummy_inputs( The sequence length to export the model for (-1 means dynamic axis). is_pair (`bool`, *optional*, defaults to `False`): Indicate if the input is a pair (sentence 1, sentence 2) - framework (`TensorType`, *optional*, defaults to `None`): - The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for. num_channels (`int`, *optional*, defaults to 3): The number of channels of the generated images. image_width (`int`, *optional*, defaults to 40): @@ -192,7 +189,7 @@ def generate_dummy_inputs( """ input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) - inputs = dict(preprocessor(images=input_image, return_tensors=framework)) + inputs = dict(preprocessor(images=input_image, return_tensors="pt")) return inputs diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index 8f9fbd706b32..ad4910dcb8c1 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -484,7 +484,6 @@ def __init__(self, config): self.is_causal = False self.scaling = self.head_dim**-0.5 - # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py index 18bfacb75592..f777cd862408 100644 --- a/src/transformers/models/layoutlm/configuration_layoutlm.py +++ b/src/transformers/models/layoutlm/configuration_layoutlm.py @@ -20,7 +20,7 @@ from ... import PretrainedConfig, PreTrainedTokenizer from ...onnx import OnnxConfig, PatchingSpec -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -154,32 +154,30 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: """ - Generate inputs to provide to the ONNX exporter for the specific framework + Generate inputs to provide to the ONNX exporter Args: tokenizer: The tokenizer associated with this model configuration batch_size: The batch size (int) to export the model for (-1 means dynamic axis) seq_length: The sequence length (int) to export the model for (-1 means dynamic axis) is_pair: Indicate if the input is a pair (sentence 1, sentence 2) - framework: The framework (optional) the tokenizer will generate tensor for Returns: Mapping[str, Tensor] holding the kwargs to provide to the model's forward function """ input_dict = super().generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) # Generate a dummy bbox box = [48, 84, 73, 128] - if not framework == TensorType.PYTORCH: - raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.") - if not is_torch_available(): raise ValueError("Cannot generate dummy inputs without PyTorch installed.") import torch diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py index c845bb43b346..b78760743832 100644 --- a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from packaging import version @@ -28,7 +28,6 @@ if TYPE_CHECKING: from ...processing_utils import ProcessorMixin - from ...utils import TensorType logger = logging.get_logger(__name__) @@ -227,13 +226,12 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, ) -> Mapping[str, Any]: """ - Generate inputs to provide to the ONNX exporter for the specific framework + Generate inputs to provide to the ONNX exporter Args: processor ([`ProcessorMixin`]): @@ -244,8 +242,6 @@ def generate_dummy_inputs( The sequence length to export the model for (-1 means dynamic axis). is_pair (`bool`, *optional*, defaults to `False`): Indicate if the input is a pair (sentence 1, sentence 2). - framework (`TensorType`, *optional*, defaults to `None`): - The framework (PyTorch or TensorFlow) that the processor will generate tensors for. num_channels (`int`, *optional*, defaults to 3): The number of channels of the generated images. image_width (`int`, *optional*, defaults to 40): @@ -284,7 +280,7 @@ def generate_dummy_inputs( dummy_image, text=dummy_text, boxes=dummy_bboxes, - return_tensors=framework, + return_tensors="pt", ) ) diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index bb00d16c3965..4ce6b3b28328 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -46,8 +46,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/longformer/configuration_longformer.py b/src/transformers/models/longformer/configuration_longformer.py index 207cc1839479..111ede4d0dd6 100644 --- a/src/transformers/models/longformer/configuration_longformer.py +++ b/src/transformers/models/longformer/configuration_longformer.py @@ -20,7 +20,7 @@ from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig -from ...utils import TensorType, logging +from ...utils import logging if TYPE_CHECKING: @@ -188,10 +188,12 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: inputs = super().generate_dummy_inputs( - preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + preprocessor=tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) import torch diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index cdc708924967..c42ace1aae2a 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -392,8 +392,6 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index e78197beeb57..95e71e8e4a1c 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -298,8 +298,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index fc95ed11b079..33014f655607 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -1403,7 +1403,7 @@ def pad( Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, `self.pad_token_id` and `self.pad_token_type_id`) .. note:: If the `encoded_inputs` passed - are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless + are dictionary of numpy arrays or PyTorch tensors the result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the specific device of your tensors however. diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 5b81ed662e8a..c2183b6f41c8 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -187,7 +187,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/m2m_100/configuration_m2m_100.py b/src/transformers/models/m2m_100/configuration_m2m_100.py index 620641f1cf4e..ff4f6f0d1af8 100644 --- a/src/transformers/models/m2m_100/configuration_m2m_100.py +++ b/src/transformers/models/m2m_100/configuration_m2m_100.py @@ -16,13 +16,13 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -189,7 +189,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -206,7 +205,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm @@ -216,16 +215,15 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index 0e0468c50b5f..fd68286b9bed 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -16,13 +16,13 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -230,16 +230,21 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, + batch_size, + seq_length, + is_pair, ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, + batch_size, + decoder_seq_length, + is_pair, ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -298,10 +303,12 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, + batch_size, + seq_length, + is_pair, ) if self.use_past: @@ -338,7 +345,6 @@ def _generate_dummy_inputs_for_encoder_and_decoder( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -355,7 +361,7 @@ def _generate_dummy_inputs_for_encoder_and_decoder( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs def generate_dummy_inputs( @@ -364,16 +370,21 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) else: common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) return common_inputs diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 3a3e076a5a4c..018c33d377ca 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -107,11 +107,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 104e7e00d9e5..ba0dd16553cb 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -16,13 +16,13 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...onnx.utils import compute_effective_axis_dimension -from ...utils import TensorType, is_torch_available, logging +from ...utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -229,16 +229,15 @@ def _generate_dummy_inputs_for_default_and_seq2seq_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, decoder_seq_length, is_pair, framework + tokenizer, batch_size, decoder_seq_length, is_pair ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) @@ -297,10 +296,9 @@ def _generate_dummy_inputs_for_causal_lm( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size, seq_length, is_pair, framework + tokenizer, batch_size, seq_length, is_pair ) if self.use_past: @@ -335,7 +333,6 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # Copied from OnnxConfig.generate_dummy_inputs # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. @@ -352,7 +349,7 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering( # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + common_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) return common_inputs def generate_dummy_inputs( @@ -361,20 +358,19 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: if self.task in ["default", "seq2seq-lm"]: common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) elif self.task == "causal-lm": common_inputs = self._generate_dummy_inputs_for_causal_lm( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) else: common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair ) return common_inputs diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index d6a45bbd8eb8..8b8f842c2a2a 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -58,9 +58,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - # In Megatron, layer-norm is applied after the 1st dropout. # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index be7cf08b14ab..73a963130a90 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -36,11 +36,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 3d7a210d816a..7fb1d2e490e0 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -1241,7 +1241,7 @@ def pad( Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, `self.pad_token_id` and `self.pad_token_type_id`) .. note:: If the `encoded_inputs` passed - are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless + are dictionary of numpy arrays or PyTorch tensors the result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of PyTorch tensors, you will lose the specific device of your tensors however. diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index c3d498323de4..3202a5e80a02 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -411,11 +411,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 86bee4d09b5a..c5584cb29cb1 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -469,8 +469,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 3eb1fad24019..8b67ab7001c1 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -52,8 +52,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/owlvit/configuration_owlvit.py b/src/transformers/models/owlvit/configuration_owlvit.py index d4873ff4a08b..4f615dece67e 100644 --- a/src/transformers/models/owlvit/configuration_owlvit.py +++ b/src/transformers/models/owlvit/configuration_owlvit.py @@ -16,12 +16,11 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...processing_utils import ProcessorMixin - from ...utils import TensorType from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -318,13 +317,15 @@ def generate_dummy_inputs( processor: "ProcessorMixin", batch_size: int = -1, seq_length: int = -1, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: text_input_dict = super().generate_dummy_inputs( - processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + processor.tokenizer, + batch_size=batch_size, + seq_length=seq_length, ) image_input_dict = super().generate_dummy_inputs( - processor.image_processor, batch_size=batch_size, framework=framework + processor.image_processor, + batch_size=batch_size, ) return {**text_input_dict, **image_input_dict} diff --git a/src/transformers/models/perceiver/configuration_perceiver.py b/src/transformers/models/perceiver/configuration_perceiver.py index 91e7bcd58fdc..d983779c6add 100644 --- a/src/transformers/models/perceiver/configuration_perceiver.py +++ b/src/transformers/models/perceiver/configuration_perceiver.py @@ -16,14 +16,14 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any, Union from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import FeatureExtractionMixin from ...onnx import OnnxConfig from ...onnx.utils import compute_effective_axis_dimension from ...tokenization_utils_base import PreTrainedTokenizerBase -from ...utils import TensorType, logging +from ...utils import logging logger = logging.get_logger(__name__) @@ -207,7 +207,6 @@ def generate_dummy_inputs( seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, @@ -226,14 +225,14 @@ def generate_dummy_inputs( ) # Generate dummy inputs according to compute batch and sequence dummy_input = [" ".join(["a"]) * seq_length] * batch_size - inputs = dict(preprocessor(dummy_input, return_tensors=framework)) + inputs = dict(preprocessor(dummy_input, return_tensors="pt")) inputs["inputs"] = inputs.pop("input_ids") return inputs elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values": # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) - inputs = dict(preprocessor(images=dummy_input, return_tensors=framework)) + inputs = dict(preprocessor(images=dummy_input, return_tensors="pt")) inputs["inputs"] = inputs.pop("pixel_values") return inputs else: diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 8c6dc8191630..b7ee51991f94 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -35,11 +35,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 21af67542d70..e77f0d5d748a 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -41,11 +41,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index 204198787e45..a5c2f1e97d8d 100644 --- a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -40,11 +40,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 26c3c693245f..3a1980885339 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -56,8 +56,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 4ba12384fbed..2865460718c2 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -59,7 +59,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index e51214beac7c..5247c39b7553 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -59,7 +59,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index aba89b1c309d..0b91af94bbaa 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -75,8 +75,6 @@ def __init__(self, config): else: self.map_inputs_layer = None - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 26eca95b456a..03a2195da287 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -88,7 +88,6 @@ def __init__(self, config): self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 8a81f68beadd..a60eff60c65b 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -64,11 +64,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 7e82d26c9e74..1fde52bae079 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -369,11 +369,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 116a17330923..3d1f6138b00e 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -46,8 +46,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index a732e76bda6a..7e58dacd7f84 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -51,8 +51,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index 95114e3d332c..7ecd94a8fd52 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -62,11 +62,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 18b61abbd3a4..e6cb0afc5781 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -365,11 +365,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index e010a1d8a01e..83dfe13baded 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -75,11 +75,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 1463f0f82e7e..424902ebdee3 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -182,11 +182,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index dc2b722789f2..2987e7ec7467 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -85,8 +85,6 @@ def __init__(self, config): self.number_of_token_type_embeddings = len(config.type_vocab_sizes) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 0aa06d5c33bb..1c125dac4f32 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -151,11 +151,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 75e58f9858fd..09ee7f07357d 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -229,8 +229,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py index 248bf73ff9fa..a069a888f02f 100644 --- a/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from packaging import version @@ -27,7 +27,7 @@ if TYPE_CHECKING: - from ... import PreTrainedTokenizerBase, TensorType + from ... import PreTrainedTokenizerBase logger = logging.get_logger(__name__) @@ -154,14 +154,16 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, ) -> Mapping[str, Any]: import torch common_inputs = OrderedDict() dummy_input = super().generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) batch, encoder_sequence = dummy_input["input_ids"].shape diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index cdc3e3adc69b..ecb1a6cd1594 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -48,9 +48,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 8debcaf11fa5..f6702bc1a124 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -266,11 +266,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index eedc94b845a4..714c3d92d827 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -350,11 +350,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py index a8837cb1cdd2..1950e03f54e2 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -16,7 +16,7 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast @@ -26,7 +26,6 @@ if TYPE_CHECKING: from ...feature_extraction_utils import FeatureExtractionMixin from ...tokenization_utils_base import PreTrainedTokenizerBase - from ...utils import TensorType logger = logging.get_logger(__name__) @@ -310,7 +309,6 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional["TensorType"] = None, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220, @@ -320,7 +318,6 @@ def generate_dummy_inputs( self, preprocessor=preprocessor.feature_extractor, batch_size=batch_size, - framework=framework, sampling_rate=sampling_rate, time_duration=time_duration, frequency=frequency, @@ -329,7 +326,10 @@ def generate_dummy_inputs( seq_length = encoder_sequence_length // 2 if self.use_past else seq_length decoder_inputs = super().generate_dummy_inputs( - preprocessor.tokenizer, batch_size, seq_length, is_pair, framework + preprocessor.tokenizer, + batch_size, + seq_length, + is_pair, ) dummy_inputs["input_features"] = encoder_inputs.pop("input_features") diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 403b9a408162..3c906a85392e 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -390,11 +390,6 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. """ if drop_prob == 0.0 or not training: return input diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 4c42d7b88615..4e0fab16b429 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -60,7 +60,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 658083839b23..c07f9e9bf760 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -57,8 +57,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index c2f25d9773e9..f7242a64d5d4 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -58,7 +58,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 0ad53b81f492..80c0dfd6539e 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -237,8 +237,6 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 2a47127b3855..46c9d32b7341 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -22,7 +22,7 @@ import numpy as np from packaging import version -from ..utils import TensorType, is_torch_available, is_vision_available, logging +from ..utils import is_torch_available, is_vision_available, logging from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size @@ -287,7 +287,6 @@ def generate_dummy_inputs( seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, @@ -297,7 +296,7 @@ def generate_dummy_inputs( tokenizer: Optional["PreTrainedTokenizerBase"] = None, ) -> Mapping[str, Any]: """ - Generate inputs to provide to the ONNX exporter for the specific framework + Generate inputs to provide to the ONNX exporter Args: preprocessor: ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]): @@ -310,8 +309,6 @@ def generate_dummy_inputs( The sequence length to export the model for (-1 means dynamic axis). is_pair (`bool`, *optional*, defaults to `False`): Indicate if the input is a pair (sentence 1, sentence 2) - framework (`TensorType`, *optional*, defaults to `None`): - The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for. num_channels (`int`, *optional*, defaults to 3): The number of channels of the generated images. image_width (`int`, *optional*, defaults to 40): @@ -371,8 +368,8 @@ def generate_dummy_inputs( # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length] for k, v in tokenized_input.items(): tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)] - return dict(tokenized_input.convert_to_tensors(tensor_type=framework)) - return dict(preprocessor(dummy_input, return_tensors=framework)) + return dict(tokenized_input.convert_to_tensors(tensor_type="pt")) + return dict(preprocessor(dummy_input, return_tensors="pt")) elif isinstance(preprocessor, ImageProcessingMixin): if preprocessor.model_input_names[0] != "pixel_values": raise ValueError( @@ -382,19 +379,19 @@ def generate_dummy_inputs( # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) - return dict(preprocessor(images=dummy_input, return_tensors=framework)) + return dict(preprocessor(images=dummy_input, return_tensors="pt")) elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values": # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) - return dict(preprocessor(images=dummy_input, return_tensors=framework)) + return dict(preprocessor(images=dummy_input, return_tensors="pt")) elif ( isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features" ): # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency) - return dict(preprocessor(dummy_input, return_tensors=framework)) + return dict(preprocessor(dummy_input, return_tensors="pt")) else: raise ValueError( "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor." @@ -514,11 +511,13 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: # TODO: should we set seq_length = 1 when self.use_past = True? common_inputs = super().generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) if self.use_past: @@ -646,16 +645,21 @@ def generate_dummy_inputs( batch_size: int = -1, seq_length: int = -1, is_pair: bool = False, - framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=seq_length, + is_pair=is_pair, ) # Generate decoder inputs decoder_seq_length = seq_length if not self.use_past else 1 decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework + tokenizer, + batch_size=batch_size, + seq_length=decoder_seq_length, + is_pair=is_pair, ) decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} common_inputs = dict(**encoder_inputs, **decoder_inputs) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 8bac1b7cb235..3e7360998334 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -24,7 +24,6 @@ from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import ( - TensorType, is_torch_available, logging, ) @@ -136,7 +135,7 @@ def export_pytorch( # Ensure inputs match # TODO: Check when exporting QA we provide "is_pair=True" - model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH) + model_inputs = config.generate_dummy_inputs(preprocessor) device = torch.device(device) if device.type == "cuda" and torch.cuda.is_available(): model.to(device) @@ -267,7 +266,6 @@ def validate_model_outputs( preprocessor, batch_size=config.default_fixed_batch + 1, seq_length=config.default_fixed_sequence + 1, - framework=TensorType.PYTORCH, ) # Create ONNX Runtime session diff --git a/tests/utils/test_auto_docstring.py b/tests/utils/test_auto_docstring.py index 1874631f5e87..4dc3ec1efb27 100644 --- a/tests/utils/test_auto_docstring.py +++ b/tests/utils/test_auto_docstring.py @@ -26,13 +26,13 @@ GEMMA3_IMAGE_PROCESSOR_FAST_DOCSTRING = """\nConstructs a fast Gemma3 image processor.\n\nParameters:\n do_resize (`Optional[bool]`, defaults to `True`):\n Whether to resize the image.\n size (`Optional[dict[str, int]]`, defaults to `{\'height\': 224, \'width\': 224}`):\n Describes the maximum input dimensions to the model.\n default_to_square (`Optional[bool]`, defaults to `True`):\n Whether to default to a square image when resizing, if size is an int.\n resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`, defaults to `2`):\n Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only\n has an effect if `do_resize` is set to `True`.\n do_center_crop (`Optional[bool]`, defaults to `None`):\n Whether to center crop the image.\n crop_size (`Optional[dict[str, int]]`, defaults to `None`):\n Size of the output image after applying `center_crop`.\n do_rescale (`Optional[bool]`, defaults to `True`):\n Whether to rescale the image.\n rescale_factor (`Union[int, float, NoneType]`, defaults to `0.00392156862745098`):\n Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n do_normalize (`Optional[bool]`, defaults to `True`):\n Whether to normalize the image.\n image_mean (`Union[float, list[float], NoneType]`, defaults to `[0.5, 0.5, 0.5]`):\n Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n image_std (`Union[float, list[float], NoneType]`, defaults to `[0.5, 0.5, 0.5]`):\n Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n `True`.\n do_convert_rgb (`Optional[bool]`, defaults to `None`):\n Whether to convert the image to RGB.\n return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`, defaults to `None`):\n Returns stacked tensors if set to `pt, otherwise returns a list of tensors.\n data_format (`Optional[~image_utils.ChannelDimension]`, defaults to `ChannelDimension.FIRST`):\n Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.\n input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`, defaults to `None`):\n The channel dimension format for the input image. If unset, the channel dimension format is inferred\n from the input image. Can be one of:\n - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.\n device (`Optional[torch.device]`, defaults to `None`):\n The device to process the images on. If unset, the device is inferred from the input images.\n do_pan_and_scan (`Optional[bool]`, defaults to `None`):\n Whether to apply `pan_and_scan` to images.\n pan_and_scan_min_crop_size (`Optional[int]`, defaults to `None`):\n Minimum size of each crop in pan and scan.\n pan_and_scan_max_num_crops (`Optional[int]`, defaults to `None`):\n Maximum number of crops per image in pan and scan.\n pan_and_scan_min_ratio_to_activate (`Optional[float]`, defaults to `None`):\n Minimum aspect ratio to activate pan and scan.\n""" -GEMMA3_IMAGE_PROCESSOR_FAST_PREPROCESS_DOCSTRING = """ Args:\n images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[\'PIL.Image.Image\'], list[numpy.ndarray], list[\'torch.Tensor\']]`):\n Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n do_resize (`Optional[bool]`):\n Whether to resize the image.\n size (`Optional[dict[str, int]]`):\n Describes the maximum input dimensions to the model.\n default_to_square (`Optional[bool]`):\n Whether to default to a square image when resizing, if size is an int.\n resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`):\n Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only\n has an effect if `do_resize` is set to `True`.\n do_center_crop (`Optional[bool]`):\n Whether to center crop the image.\n crop_size (`Optional[dict[str, int]]`):\n Size of the output image after applying `center_crop`.\n do_rescale (`Optional[bool]`):\n Whether to rescale the image.\n rescale_factor (`Union[int, float, NoneType]`):\n Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n do_normalize (`Optional[bool]`):\n Whether to normalize the image.\n image_mean (`Union[float, list[float], NoneType]`):\n Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n image_std (`Union[float, list[float], NoneType]`):\n Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n `True`.\n do_convert_rgb (`Optional[bool]`):\n Whether to convert the image to RGB.\n return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`):\n Returns stacked tensors if set to `pt, otherwise returns a list of tensors.\n data_format (`Optional[~image_utils.ChannelDimension]`):\n Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.\n input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`):\n The channel dimension format for the input image. If unset, the channel dimension format is inferred\n from the input image. Can be one of:\n - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.\n device (`Optional[torch.device]`):\n The device to process the images on. If unset, the device is inferred from the input images.\n do_pan_and_scan (`Optional[bool]`):\n Whether to apply `pan_and_scan` to images.\n pan_and_scan_min_crop_size (`Optional[int]`):\n Minimum size of each crop in pan and scan.\n pan_and_scan_max_num_crops (`Optional[int]`):\n Maximum number of crops per image in pan and scan.\n pan_and_scan_min_ratio_to_activate (`Optional[float]`):\n Minimum aspect ratio to activate pan and scan.\n\n Returns:\n ``:\n - **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method (\'pixel_values\', etc.).\n - **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at\n initialization.\n""" +GEMMA3_IMAGE_PROCESSOR_FAST_PREPROCESS_DOCSTRING = """ Args:\n images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[\'PIL.Image.Image\'], list[numpy.ndarray], list[\'torch.Tensor\']]`):\n Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n do_resize (`Optional[bool]`):\n Whether to resize the image.\n size (`Optional[dict[str, int]]`):\n Describes the maximum input dimensions to the model.\n default_to_square (`Optional[bool]`):\n Whether to default to a square image when resizing, if size is an int.\n resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`):\n Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only\n has an effect if `do_resize` is set to `True`.\n do_center_crop (`Optional[bool]`):\n Whether to center crop the image.\n crop_size (`Optional[dict[str, int]]`):\n Size of the output image after applying `center_crop`.\n do_rescale (`Optional[bool]`):\n Whether to rescale the image.\n rescale_factor (`Union[int, float, NoneType]`):\n Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n do_normalize (`Optional[bool]`):\n Whether to normalize the image.\n image_mean (`Union[float, list[float], NoneType]`):\n Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n image_std (`Union[float, list[float], NoneType]`):\n Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n `True`.\n do_convert_rgb (`Optional[bool]`):\n Whether to convert the image to RGB.\n return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`):\n Returns stacked tensors if set to `pt, otherwise returns a list of tensors.\n data_format (`Optional[~image_utils.ChannelDimension]`):\n Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.\n input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`):\n The channel dimension format for the input image. If unset, the channel dimension format is inferred\n from the input image. Can be one of:\n - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.\n device (`Optional[torch.device]`):\n The device to process the images on. If unset, the device is inferred from the input images.\n do_pan_and_scan (`Optional[bool]`):\n Whether to apply `pan_and_scan` to images.\n pan_and_scan_min_crop_size (`Optional[int]`):\n Minimum size of each crop in pan and scan.\n pan_and_scan_max_num_crops (`Optional[int]`):\n Maximum number of crops per image in pan and scan.\n pan_and_scan_min_ratio_to_activate (`Optional[float]`):\n Minimum aspect ratio to activate pan and scan.\n\n Returns:\n ``:\n - **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method (\'pixel_values\', etc.).\n - **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at\n initialization.\n""" class AutoDocstringTest(unittest.TestCase): pass # def test_modeling_docstring(self): - # llama_docstring = " Args:\n images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']]`):\n Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n do_resize (`Optional[bool]`):\n Whether to resize the image.\n size (`Optional[dict[str, int]]`):\n Describes the maximum input dimensions to the model.\n default_to_square (`Optional[bool]`):\n Whether to default to a square image when resizing, if size is an int.\n resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`):\n Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only\n has an effect if `do_resize` is set to `True`.\n do_center_crop (`Optional[bool]`):\n Whether to center crop the image.\n crop_size (`Optional[dict[str, int]]`):\n Size of the output image after applying `center_crop`.\n do_rescale (`Optional[bool]`):\n Whether to rescale the image.\n rescale_factor (`Union[int, float, NoneType]`):\n Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n do_normalize (`Optional[bool]`):\n Whether to normalize the image.\n image_mean (`Union[float, list[float], NoneType]`):\n Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n image_std (`Union[float, list[float], NoneType]`):\n Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n `True`.\n do_convert_rgb (`Optional[bool]`):\n Whether to convert the image to RGB.\n return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`):\n Returns stacked tensors if set to `pt, otherwise returns a list of tensors.\n data_format (`Optional[~image_utils.ChannelDimension]`):\n Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.\n input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`):\n The channel dimension format for the input image. If unset, the channel dimension format is inferred\n from the input image. Can be one of:\n - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n - `\"none\"` or `ChannelDimension.NONE`: image in (height, width) format.\n device (`Optional[torch.device]`):\n The device to process the images on. If unset, the device is inferred from the input images.\n do_pan_and_scan (`Optional[bool]`):\n Whether to apply `pan_and_scan` to images.\n pan_and_scan_min_crop_size (`Optional[int]`):\n Minimum size of each crop in pan and scan.\n pan_and_scan_max_num_crops (`Optional[int]`):\n Maximum number of crops per image in pan and scan.\n pan_and_scan_min_ratio_to_activate (`Optional[float]`):\n Minimum aspect ratio to activate pan and scan.\n\n Returns:\n ``:\n - **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).\n - **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at\n initialization.\n" + # llama_docstring = " Args:\n images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']]`):\n Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n do_resize (`Optional[bool]`):\n Whether to resize the image.\n size (`Optional[dict[str, int]]`):\n Describes the maximum input dimensions to the model.\n default_to_square (`Optional[bool]`):\n Whether to default to a square image when resizing, if size is an int.\n resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`):\n Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only\n has an effect if `do_resize` is set to `True`.\n do_center_crop (`Optional[bool]`):\n Whether to center crop the image.\n crop_size (`Optional[dict[str, int]]`):\n Size of the output image after applying `center_crop`.\n do_rescale (`Optional[bool]`):\n Whether to rescale the image.\n rescale_factor (`Union[int, float, NoneType]`):\n Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n do_normalize (`Optional[bool]`):\n Whether to normalize the image.\n image_mean (`Union[float, list[float], NoneType]`):\n Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n image_std (`Union[float, list[float], NoneType]`):\n Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n `True`.\n do_convert_rgb (`Optional[bool]`):\n Whether to convert the image to RGB.\n return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`):\n Returns stacked tensors if set to `pt, otherwise returns a list of tensors.\n data_format (`Optional[~image_utils.ChannelDimension]`):\n Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.\n input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`):\n The channel dimension format for the input image. If unset, the channel dimension format is inferred\n from the input image. Can be one of:\n - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n - `\"none\"` or `ChannelDimension.NONE`: image in (height, width) format.\n device (`Optional[torch.device]`):\n The device to process the images on. If unset, the device is inferred from the input images.\n do_pan_and_scan (`Optional[bool]`):\n Whether to apply `pan_and_scan` to images.\n pan_and_scan_min_crop_size (`Optional[int]`):\n Minimum size of each crop in pan and scan.\n pan_and_scan_max_num_crops (`Optional[int]`):\n Maximum number of crops per image in pan and scan.\n pan_and_scan_min_ratio_to_activate (`Optional[float]`):\n Minimum aspect ratio to activate pan and scan.\n\n Returns:\n ``:\n - **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).\n - **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at\n initialization.\n" # self.assertEqual(llama_docstring, LlamaModel.__doc__) # self.assertEqual(LLAMA_MODEL_DOCSTRING, LlamaModel.forward.__doc__) From 0e9fd50beabafbdcbfddefc5e5bb80750d4be5e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 16:46:11 +0200 Subject: [PATCH 25/35] still on the grind --- src/transformers/models/longt5/modeling_longt5.py | 3 --- src/transformers/models/luke/tokenization_luke.py | 4 ++-- src/transformers/models/mllama/processing_mllama.py | 2 -- src/transformers/models/mluke/tokenization_mluke.py | 4 ++-- src/transformers/models/mt5/modeling_mt5.py | 1 - src/transformers/models/openai/modeling_openai.py | 4 ++-- .../models/perceiver/modeling_perceiver.py | 2 -- .../models/pix2struct/modeling_pix2struct.py | 2 -- .../models/pop2piano/modeling_pop2piano.py | 1 - .../modeling_switch_transformers.py | 1 - src/transformers/models/t5/modeling_t5.py | 1 - src/transformers/models/tapas/tokenization_tapas.py | 2 +- src/transformers/models/udop/modeling_udop.py | 1 - src/transformers/models/umt5/modeling_umt5.py | 1 - src/transformers/onnx/convert.py | 2 +- src/transformers/pipelines/base.py | 2 +- tests/models/tapas/test_modeling_tapas.py | 12 ------------ utils/check_repo.py | 4 ---- 18 files changed, 9 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 70eec28d89f4..a3499fb2a0ba 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -347,7 +347,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -570,7 +569,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -765,7 +763,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 33014f655607..4bb19bb5ee73 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -1412,8 +1412,8 @@ def pad( Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str, list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader - collate function. Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or - TensorFlow tensors), see the note above for the return type. + collate function. Instead of `list[int]` you can have tensors (numpy arrays, or PyTorch tensors), + see the note above for the return type. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py index 0dae7c834303..a5a0ae8739b3 100644 --- a/src/transformers/models/mllama/processing_mllama.py +++ b/src/transformers/models/mllama/processing_mllama.py @@ -246,10 +246,8 @@ def __call__( `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 7fb1d2e490e0..d63129c7b7e4 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -1250,8 +1250,8 @@ def pad( Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str, list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader - collate function. Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or - TensorFlow tensors), see the note above for the return type. + collate function. Instead of `list[int]` you can have tensors (numpy arrays, or PyTorch tensors), + see the note above for the return type. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 451db734a3a5..50d514be4a14 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -244,7 +244,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index d651f04a1bd1..b9de7b999a09 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -107,11 +107,11 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions= def merge_heads(self, x): x = x.permute(0, 2, 1, 3).contiguous() new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) - return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states + return x.view(*new_x_shape) def split_heads(self, x, k=False): new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) - x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states + x = x.view(*new_x_shape) if k: return x.permute(0, 2, 3, 1) else: diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index f0e4e3e5dbe0..ef0786074f9f 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -2379,8 +2379,6 @@ def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_b Space to depth transform. Rearranges blocks of spatial data, into depth. This function assumes the channels to be first, but will place the channels last after transformation. - - Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15. """ if len(frames.shape) == 4: batch_size, num_channels, height, width = frames.shape diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index c79e937fc28a..049cb7f0bd43 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -142,7 +142,6 @@ def __init__(self, config): self.dropout = config.attention_dropout self.inner_dim = self.n_heads * self.key_value_proj_dim - # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) @@ -649,7 +648,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 74d9ebc97dc2..83bc759e72cf 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -189,7 +189,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a2e3f8fb4c10..761f1c1ccc8f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -382,7 +382,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 7cd5e6394afd..f74569574c8f 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -249,7 +249,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py index 44558f5542e4..098bee1dd7f5 100644 --- a/src/transformers/models/tapas/tokenization_tapas.py +++ b/src/transformers/models/tapas/tokenization_tapas.py @@ -2206,7 +2206,7 @@ def tokenize(self, text): return output_tokens -# Below: utilities for TAPAS tokenizer (independent from PyTorch/Tensorflow). +# Below: utilities for TAPAS tokenizer # This includes functions to parse numeric values (dates and numbers) from both the table and questions in order # to create the column_ranks, inv_column_ranks, numeric_values, numeric values_scale and numeric_relations in # prepare_for_model of TapasTokenizer. diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index deacc075c72b..22f45731030e 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -460,7 +460,6 @@ def __init__( "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index a8c592b727c6..6d7437b5a6e5 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -180,7 +180,6 @@ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optiona "when creating this class." ) - # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 3e7360998334..f76f08b23db5 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -188,7 +188,7 @@ def export( device: str = "cpu", ) -> tuple[list[str], list[str]]: """ - Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR) + Export a Pytorch model to an ONNX Intermediate Representation (IR) Args: preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]): diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 71450199580c..db7a0d77da64 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -973,7 +973,7 @@ def save_pretrained( save_directory (`str` or `os.PathLike`): A path to the directory where to saved. It will be created if it doesn't exist. safe_serialization (`str`): - Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow. + Whether to save the model using `safetensors` or PyTorch serialization. kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index c96c7691687c..9ffd1ca6a80f 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -972,11 +972,9 @@ def test_product_index(self): self.assertEqual(cell_index.num_segments, 9) # Projections should give back the original indices. - # we use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(row_index.indices.numpy(), row_index_proj.indices.numpy()) self.assertEqual(row_index.num_segments, row_index_proj.num_segments) self.assertEqual(row_index.batch_dims, row_index_proj.batch_dims) - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(col_index.indices.numpy(), col_index_proj.indices.numpy()) self.assertEqual(col_index.batch_dims, col_index_proj.batch_dims) @@ -1006,7 +1004,6 @@ def test_flatten(self): batched_index = IndexMap(indices=torch.zeros(shape).type(torch.LongTensor), num_segments=1, batch_dims=3) batched_index_flat = flatten(batched_index) - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal( row_index_flat.indices.numpy(), [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5] ) @@ -1024,11 +1021,9 @@ def test_range_index_map(self): self.assertEqual(num_segments, index.num_segments) self.assertEqual(2, index.batch_dims) indices = index.indices - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(list(indices.size()), [3, 4, 5]) for i in range(batch_shape[0]): for j in range(batch_shape[1]): - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(indices[i, j, :].numpy(), range(num_segments)) def test_reduce_sum(self): @@ -1038,7 +1033,6 @@ def test_reduce_sum(self): col_sum, _ = reduce_sum(values, col_index) cell_sum, _ = reduce_sum(values, cell_index) - # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose np.testing.assert_allclose(row_sum.numpy(), [[6.0, 3.0, 8.0], [6.0, 3.0, 8.0]]) np.testing.assert_allclose(col_sum.numpy(), [[9.0, 8.0, 0.0], [4.0, 5.0, 8.0]]) np.testing.assert_allclose( @@ -1053,7 +1047,6 @@ def test_reduce_mean(self): col_mean, _ = reduce_mean(values, col_index) cell_mean, _ = reduce_mean(values, cell_index) - # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose np.testing.assert_allclose( row_mean.numpy(), [[6.0 / 3.0, 3.0 / 3.0, 8.0 / 3.0], [6.0 / 3.0, 3.0 / 3.0, 8.0 / 3.0]] ) @@ -1071,7 +1064,6 @@ def test_reduce_max(self): index = IndexMap(indices=torch.as_tensor([0, 1, 0, 1]), num_segments=2) maximum, _ = reduce_max(values, index) - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(maximum.numpy(), [2, 3]) def test_reduce_sum_vectorized(self): @@ -1079,9 +1071,7 @@ def test_reduce_sum_vectorized(self): index = IndexMap(indices=torch.as_tensor([[0, 0, 1]]), num_segments=2, batch_dims=0) sums, new_index = reduce_sum(values, index) - # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose np.testing.assert_allclose(sums.numpy(), [3.0, 3.0]) - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(new_index.indices.numpy(), [0, 1]) np.testing.assert_array_equal(new_index.num_segments.numpy(), 2) np.testing.assert_array_equal(new_index.batch_dims, 0) @@ -1097,7 +1087,6 @@ def test_gather(self): cell_sum = gather(sums, cell_index) assert cell_sum.size() == values.size() - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_allclose( cell_sum.numpy(), [[[3.0, 3.0, 3.0], [2.0, 2.0, 1.0], [4.0, 4.0, 4.0]], [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]]], @@ -1108,5 +1097,4 @@ def test_gather_vectorized(self): index = IndexMap(indices=torch.as_tensor([[0, 1], [1, 0]]), num_segments=2, batch_dims=1) result = gather(values, index) - # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(result.numpy(), [[[1, 2], [3, 4]], [[7, 8], [5, 6]]]) diff --git a/utils/check_repo.py b/utils/check_repo.py index 6f95f2662d55..e92ac6f3b7eb 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -882,14 +882,12 @@ def find_all_documented_objects() -> list[str]: # One good reason for not being documented is to be deprecated. Put in this list deprecated objects. DEPRECATED_OBJECTS = [ - "AdamWeightDecay", # TensorFlow object, support is deprecated "AutoModelWithLMHead", "BartPretrainedModel", "DataCollator", "DataCollatorForSOP", "GlueDataset", "GlueDataTrainingArguments", - "GradientAccumulator", # TensorFlow object, support is deprecated "LineByLineTextDataset", "LineByLineWithRefDataset", "LineByLineWithSOPTextDataset", @@ -906,10 +904,8 @@ def find_all_documented_objects() -> list[str]: "SquadV2Processor", "TextDataset", "TextDatasetForNextSentencePrediction", - "WarmUp", # TensorFlow object, support is deprecated "Wav2Vec2ForMaskedLM", "Wav2Vec2Tokenizer", - "create_optimizer", # TensorFlow object, support is deprecated "glue_compute_metrics", "glue_convert_examples_to_features", "glue_output_modes", From ad8cfecf9715bd2dc74bf440b415e758896e5940 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 18:48:24 +0200 Subject: [PATCH 26/35] always more references --- examples/modular-transformers/modeling_test_detr.py | 2 -- examples/pytorch/question-answering/utils_qa.py | 4 ++-- src/transformers/configuration_utils.py | 4 ---- src/transformers/integrations/integration_utils.py | 2 +- src/transformers/models/bark/modeling_bark.py | 2 -- src/transformers/models/beit/modeling_beit.py | 2 -- src/transformers/models/blip/modeling_blip_text.py | 2 -- src/transformers/models/bloom/modeling_bloom.py | 2 -- src/transformers/models/bros/modeling_bros.py | 2 -- .../conditional_detr/modeling_conditional_detr.py | 2 -- src/transformers/models/convnext/modeling_convnext.py | 2 -- .../models/convnextv2/modeling_convnextv2.py | 2 -- src/transformers/models/ctrl/modeling_ctrl.py | 2 -- src/transformers/models/dab_detr/modeling_dab_detr.py | 2 -- .../models/data2vec/modeling_data2vec_text.py | 2 -- .../models/data2vec/modeling_data2vec_vision.py | 2 -- src/transformers/models/deberta/modeling_deberta.py | 4 +--- .../models/deberta_v2/modeling_deberta_v2.py | 2 -- .../modeling_decision_transformer.py | 4 ---- .../deformable_detr/modeling_deformable_detr.py | 2 -- .../models/deprecated/deta/modeling_deta.py | 2 -- .../models/deprecated/ernie_m/modeling_ernie_m.py | 2 -- .../models/deprecated/mctct/modeling_mctct.py | 2 -- .../models/deprecated/nat/modeling_nat.py | 2 -- .../deprecated/transfo_xl/tokenization_transfo_xl.py | 1 - .../models/deprecated/tvlt/modeling_tvlt.py | 2 -- .../models/depth_anything/modeling_depth_anything.py | 2 -- .../models/depth_pro/modeling_depth_pro.py | 2 -- src/transformers/models/detr/modeling_detr.py | 2 -- src/transformers/models/dinat/modeling_dinat.py | 2 -- .../dinov3_convnext/modeling_dinov3_convnext.py | 2 -- .../models/distilbert/modeling_distilbert.py | 2 -- src/transformers/models/donut/modeling_donut_swin.py | 2 -- src/transformers/models/dpr/modeling_dpr.py | 2 -- src/transformers/models/dpt/modeling_dpt.py | 2 -- .../models/efficientnet/modeling_efficientnet.py | 2 -- src/transformers/models/ernie/modeling_ernie.py | 2 -- src/transformers/models/falcon/modeling_falcon.py | 2 -- src/transformers/models/flava/modeling_flava.py | 2 -- src/transformers/models/fnet/modeling_fnet.py | 2 -- src/transformers/models/focalnet/modeling_focalnet.py | 2 -- src/transformers/models/git/modeling_git.py | 2 -- src/transformers/models/glpn/modeling_glpn.py | 2 -- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 2 -- .../models/grounding_dino/modeling_grounding_dino.py | 2 -- src/transformers/models/groupvit/modeling_groupvit.py | 2 -- src/transformers/models/hubert/modeling_hubert.py | 2 -- src/transformers/models/hubert/modular_hubert.py | 2 -- src/transformers/models/ibert/modeling_ibert.py | 2 -- src/transformers/models/layoutlm/modeling_layoutlm.py | 2 -- .../models/layoutlmv2/modeling_layoutlmv2.py | 2 -- .../models/layoutlmv3/modeling_layoutlmv3.py | 2 -- src/transformers/models/levit/modeling_levit.py | 2 -- src/transformers/models/lilt/modeling_lilt.py | 2 -- .../models/longformer/modeling_longformer.py | 2 -- src/transformers/models/marian/tokenization_marian.py | 4 ++-- .../models/maskformer/modeling_maskformer.py | 2 -- .../models/maskformer/modeling_maskformer_swin.py | 2 -- .../mm_grounding_dino/modeling_mm_grounding_dino.py | 2 -- .../models/mobilevit/modeling_mobilevit.py | 2 -- .../models/mobilevitv2/modeling_mobilevitv2.py | 2 -- src/transformers/models/mpnet/modeling_mpnet.py | 2 -- src/transformers/models/mpt/modeling_mpt.py | 2 -- src/transformers/models/mra/modeling_mra.py | 2 -- .../models/nystromformer/modeling_nystromformer.py | 2 -- src/transformers/models/openai/modeling_openai.py | 4 ---- .../models/perceiver/modeling_perceiver.py | 2 -- .../models/pop2piano/tokenization_pop2piano.py | 2 +- src/transformers/models/reformer/modeling_reformer.py | 2 -- .../models/sam/image_processing_sam_fast.py | 2 +- .../models/sam2/image_processing_sam2_fast.py | 2 +- .../models/segformer/modeling_segformer.py | 2 -- src/transformers/models/sew/modeling_sew.py | 2 -- src/transformers/models/sew/modular_sew.py | 2 -- src/transformers/models/sew_d/modeling_sew_d.py | 2 -- src/transformers/models/splinter/modeling_splinter.py | 2 -- .../models/squeezebert/modeling_squeezebert.py | 2 -- .../models/superglue/modeling_superglue.py | 2 -- .../models/superpoint/modeling_superpoint.py | 2 -- src/transformers/models/swin/modeling_swin.py | 2 -- src/transformers/models/swinv2/modeling_swinv2.py | 2 -- .../table_transformer/modeling_table_transformer.py | 2 -- src/transformers/models/tapas/tokenization_tapas.py | 5 ++--- src/transformers/models/tvp/modeling_tvp.py | 2 -- src/transformers/models/videomae/modeling_videomae.py | 2 -- src/transformers/models/vilt/modeling_vilt.py | 2 -- .../models/visual_bert/modeling_visual_bert.py | 2 -- src/transformers/models/vit_mae/modeling_vit_mae.py | 2 -- src/transformers/models/vit_msn/modeling_vit_msn.py | 2 -- src/transformers/models/vivit/modeling_vivit.py | 2 -- .../models/whisper/tokenization_whisper.py | 6 +++--- .../models/whisper/tokenization_whisper_fast.py | 6 +++--- src/transformers/models/yolos/modeling_yolos.py | 2 -- src/transformers/models/yoso/modeling_yoso.py | 2 -- src/transformers/models/zoedepth/modeling_zoedepth.py | 2 -- src/transformers/pipelines/base.py | 2 +- tests/generation/test_utils.py | 2 -- tests/models/blip/test_modeling_blip.py | 2 +- tests/models/rembert/test_modeling_rembert.py | 11 ----------- tests/models/tapas/test_modeling_tapas.py | 4 ---- tests/models/vit_mae/test_modeling_vit_mae.py | 7 ------- .../test_pipelines_zero_shot_audio_classification.py | 8 -------- utils/notification_service.py | 2 +- utils/update_tiny_models.py | 7 ------- 104 files changed, 20 insertions(+), 237 deletions(-) diff --git a/examples/modular-transformers/modeling_test_detr.py b/examples/modular-transformers/modeling_test_detr.py index 11e6719479a4..3ff225c0b3ff 100644 --- a/examples/modular-transformers/modeling_test_detr.py +++ b/examples/modular-transformers/modeling_test_detr.py @@ -846,8 +846,6 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/examples/pytorch/question-answering/utils_qa.py b/examples/pytorch/question-answering/utils_qa.py index b30322b0071f..477e193cbe7b 100644 --- a/examples/pytorch/question-answering/utils_qa.py +++ b/examples/pytorch/question-answering/utils_qa.py @@ -185,7 +185,7 @@ def postprocess_qa_predictions( if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) - # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # Compute the softmax of all scores (we do it with numpy to stay independent from torch in this file, using # the LogSumExp trick). scores = np.array([pred.pop("score") for pred in predictions]) exp_scores = np.exp(scores - np.max(scores)) @@ -392,7 +392,7 @@ def postprocess_qa_predictions_with_beam_search( min_null_score = -2e-6 predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score}) - # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # Compute the softmax of all scores (we do it with numpy to stay independent from torch in this file, using # the LogSumExp trick). scores = np.array([pred.pop("score") for pred in predictions]) exp_scores = np.exp(scores - np.max(scores)) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6d06d72353e8..b9423a8bbf59 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -344,10 +344,6 @@ def __init__( logger.error(f"Can't set {key} with value {value} for {self}") raise err - # TODO: remove later, deprecated arguments for TF models - self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) - self.use_bfloat16 = kwargs.pop("use_bfloat16", False) - def _create_id_label_maps(self, num_labels: int): self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 3d37b3396485..6cec1183c5c7 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1216,7 +1216,7 @@ def setup(self, args, state, model): - **COMET_PROJECT_NAME** (`str`, *optional*): Comet project name for experiments. - **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`): - Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or + Whether or not to log training assets (checkpoints, etc), to Comet. Can be `TRUE`, or `FALSE`. For a number of configurable items in the environment, see diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 8770e3e0691b..475b85cf7e8e 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -335,8 +335,6 @@ class BarkPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 18f4fad10db9..9b6e7f1cd1a6 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -727,8 +727,6 @@ class BeitPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 427adaf659db..99026a2b4fd0 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -579,8 +579,6 @@ class BlipTextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6fde63e03b4d..84e31fddfb2e 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -441,8 +441,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 7bb9e0ac762a..5f5dd05ff82d 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -576,8 +576,6 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index ada06d87f7cc..2ee35cc19a3f 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -982,8 +982,6 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index d859c89ecb97..e3224c29405f 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -242,8 +242,6 @@ class ConvNextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 8fb27efbe502..3bf6130824ed 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -262,8 +262,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 506bed039b17..03da5b51c907 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -216,8 +216,6 @@ class CTRLPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 4b7a27e7663b..cbb7450c7f0b 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -824,8 +824,6 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 0fbe9e4802ad..1d901908f818 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -566,8 +566,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index c5229c3a7540..e59258625210 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -741,8 +741,6 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 461572b47677..3074db6ca00a 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -38,7 +38,7 @@ class DebertaLayerNorm(nn.Module): - """LayerNorm module in the TF style (epsilon inside the square root).""" + """LayerNorm module (epsilon inside the square root).""" def __init__(self, size, eps=1e-12): super().__init__() @@ -617,8 +617,6 @@ class DebertaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 9d06f00c0ce6..71bf04b95542 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -697,8 +697,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 9aa7860b7d6e..a16c9ab71075 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -410,8 +410,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() @@ -682,8 +680,6 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 34f5bce7a5c4..657e78be87ef 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -959,8 +959,6 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index c4f6f5c65ded..1e233608e3e4 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -1041,8 +1041,6 @@ def _init_weights(self, module): elif isinstance(module, DetaMultiscaleDeformableAttention): module._reset_parameters() elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index f0e97c132d09..4cecdf5728a3 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -409,8 +409,6 @@ class ErnieMPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index f4982935336d..357b8b2c3681 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -425,8 +425,6 @@ def _init_weights(self, module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/nat/modeling_nat.py b/src/transformers/models/deprecated/nat/modeling_nat.py index 935ae89ef966..1667c98297fa 100644 --- a/src/transformers/models/deprecated/nat/modeling_nat.py +++ b/src/transformers/models/deprecated/nat/modeling_nat.py @@ -615,8 +615,6 @@ class NatPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py index 70e2da018556..e7081cd46d42 100644 --- a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py @@ -201,7 +201,6 @@ def __init__( try: vocab_dict = None if pretrained_vocab_file is not None: - # Priority on pickle files (support PyTorch and TF) if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): raise ValueError( "This part uses `pickle.load` which is insecure and will execute arbitrary code that is " diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index 2b21df928ff3..82aa12ada9e9 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -579,8 +579,6 @@ class TvltPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index bc7d74131204..5710016bd513 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -218,8 +218,6 @@ class DepthAnythingPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py index 52de04d42df7..9fb4c35b23e5 100644 --- a/src/transformers/models/depth_pro/modeling_depth_pro.py +++ b/src/transformers/models/depth_pro/modeling_depth_pro.py @@ -618,8 +618,6 @@ class DepthProPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 86835ca62cfc..89441a8b1246 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -739,8 +739,6 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 80349f29592d..a65b4862c473 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -583,8 +583,6 @@ class DinatPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py index da68b17c8587..8eef42c03d17 100644 --- a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +++ b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py @@ -194,8 +194,6 @@ class DINOv3ConvNextPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 8024e01a59bf..8f0cdcd76898 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -570,8 +570,6 @@ class DistilBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 21f400cb2b68..c541b960fd2e 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -829,8 +829,6 @@ class DonutSwinPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 03178c4e8564..7ee4dcaf52e1 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -108,8 +108,6 @@ class DPRPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 363fce92f897..7be71fd3ceb4 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -758,8 +758,6 @@ class DPTPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index a263ff20760c..e368fefa0e79 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -438,8 +438,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index f244efb3f01c..4c7c33fd7e43 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -626,8 +626,6 @@ class ErniePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 26dc56e41480..dac4dc658b19 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -655,8 +655,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, (nn.Linear, FalconLinear)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8705515f4270..c48f2ca1279f 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -699,8 +699,6 @@ class FlavaPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 3ed84887ce6f..b8cdd1f2ea58 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -390,8 +390,6 @@ class FNetPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # NOTE: Original code uses same initialization as weights for biases as well. if module.bias is not None: diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 9501db673037..9b5d4daed70c 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -584,8 +584,6 @@ class FocalNetPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 6b99722aa6a8..b98d2b1c231c 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -461,8 +461,6 @@ def _init_weights(self, module): nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range) nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range) if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index dc82249f2b7a..abdbcbf10e79 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -412,8 +412,6 @@ class GLPNPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6992dc642a4f..63b2ec4039f6 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -383,8 +383,6 @@ def _init_weights(self, module): ) module.c_proj._is_hf_initialized = True elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 727749548042..594524c8dd1c 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -1413,8 +1413,6 @@ def _init_weights(self, module): module.vision_param.data.fill_(1e-4) module.text_param.data.fill_(1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 775ebd286f0a..65fdaaa784d3 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -752,8 +752,6 @@ def _init_weights(self, module): init_range = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=init_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 060b715e8d49..dfa53a2cf193 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -688,8 +688,6 @@ class HubertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index facebcf445e6..d7169a85d30b 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -136,8 +136,6 @@ class HubertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index a9ab176d0bab..e1b3c7fb966c 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -626,8 +626,6 @@ class IBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (QuantLinear, nn.Linear)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 9e71eb7d8fb9..11b7fac2b78c 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -466,8 +466,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 3f444fbb6b28..f3b856518133 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -473,8 +473,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 73bf26b0dfbe..63631e12eab5 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -205,8 +205,6 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 3deca07e2400..bec62dec56e0 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -474,8 +474,6 @@ class LevitPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 4ce6b3b28328..c486a494b48a 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -565,8 +565,6 @@ class LiltPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index c42ace1aae2a..fc466a38ecc2 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1355,8 +1355,6 @@ class LongformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index ef8e1537b99d..66a3630ffd56 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -204,7 +204,7 @@ def batch_decode(self, sequences, **kwargs): Convert a list of lists of token ids into a list of strings by calling decode. Args: - sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -230,7 +230,7 @@ def decode(self, token_ids, **kwargs): Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 9e1c0072425b..02deeb4af638 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1464,8 +1464,6 @@ def _init_weights(self, module: nn.Module): module.weight.data.fill_(1.0) # copied from DETR if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 018c33d377ca..2de478440414 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -738,8 +738,6 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 3202a5e80a02..b27d6ac42a3a 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -550,8 +550,6 @@ def _init_weights(self, module): module.vision_param.data.fill_(1e-4) module.text_param.data.fill_(1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index b74987040fc7..db8b8cd58f5a 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -629,8 +629,6 @@ class MobileViTPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 0e9e616c52cd..d842acf7b6e5 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -576,8 +576,6 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index b25e5491738b..e2ea5cf300ad 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -49,8 +49,6 @@ class MPNetPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index c7bf0a795d42..18df615794cf 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -232,8 +232,6 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index c5584cb29cb1..6612336b6794 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -822,8 +822,6 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 8b67ab7001c1..03c134ccadae 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -447,8 +447,6 @@ class NystromformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index b9de7b999a09..a1b6bf2ed579 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -46,7 +46,6 @@ class Attention(nn.Module): def __init__(self, nx, n_positions, config, scale=False): super().__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implementation] if n_state % config.n_head != 0: raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}") self.register_buffer( @@ -83,7 +82,6 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions= w = torch.matmul(q, k) if self.scale: w = w / math.sqrt(v.size(-1)) - # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights # XD: self.b may be larger than w, so we need to crop it b = self.bias[:, :, : w.size(-2), : w.size(-1)] w = w * b + -1e4 * (1 - b) @@ -284,8 +282,6 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index ef0786074f9f..58267db8c19a 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -566,8 +566,6 @@ class PerceiverPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/pop2piano/tokenization_pop2piano.py b/src/transformers/models/pop2piano/tokenization_pop2piano.py index c81165a03be4..0c2feec52d57 100644 --- a/src/transformers/models/pop2piano/tokenization_pop2piano.py +++ b/src/transformers/models/pop2piano/tokenization_pop2piano.py @@ -608,7 +608,7 @@ def batch_decode( transformer to midi_notes and returns them. Args: - token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[np.ndarray, torch.Tensor]`): Output token_ids of `Pop2PianoConditionalGeneration` model. feature_extractor_output (`BatchFeature`): Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 990f21359bc0..2160663ed5ba 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -1920,8 +1920,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/sam/image_processing_sam_fast.py b/src/transformers/models/sam/image_processing_sam_fast.py index ba75e73c8680..65ee02e97dac 100644 --- a/src/transformers/models/sam/image_processing_sam_fast.py +++ b/src/transformers/models/sam/image_processing_sam_fast.py @@ -270,7 +270,7 @@ def generate_crop_boxes( input_data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. return_tensors (`str`, *optional*, defaults to `pt`): - If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + If `pt`, returns `torch.Tensor`. """ image = self._process_image(image) crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( diff --git a/src/transformers/models/sam2/image_processing_sam2_fast.py b/src/transformers/models/sam2/image_processing_sam2_fast.py index a55188f4e786..f78e8b65bea1 100644 --- a/src/transformers/models/sam2/image_processing_sam2_fast.py +++ b/src/transformers/models/sam2/image_processing_sam2_fast.py @@ -530,7 +530,7 @@ def generate_crop_boxes( input_data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. return_tensors (`str`, *optional*, defaults to `pt`): - If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + If `pt`, returns `torch.Tensor`. """ image = self._process_image(image) crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index a60eff60c65b..0b06c7c39e09 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -436,8 +436,6 @@ class SegformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index d5bd617e842b..a001cdd61d58 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -529,8 +529,6 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index b15c2e5c23a3..1cba7595079e 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -276,8 +276,6 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index f8b71241c79e..68d406c5464a 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1197,8 +1197,6 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 3d1f6138b00e..1d9c02877841 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -367,8 +367,6 @@ class SplinterPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index 7e58dacd7f84..9e26d1953f1c 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -421,8 +421,6 @@ class SqueezeBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 61c52d06e605..4fc524314e89 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -524,8 +524,6 @@ class SuperGluePreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index efd3113eb3b0..f75cc6f9bb8f 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -330,8 +330,6 @@ class SuperPointPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index e6cb0afc5781..7f9e04337ba4 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -851,8 +851,6 @@ class SwinPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 424902ebdee3..ddc4dab73768 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -922,8 +922,6 @@ class Swinv2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index d55ea6740075..4ab85689ab15 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -700,8 +700,6 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py index 098bee1dd7f5..7277f562a118 100644 --- a/src/transformers/models/tapas/tokenization_tapas.py +++ b/src/transformers/models/tapas/tokenization_tapas.py @@ -1896,9 +1896,9 @@ def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_clas Args: data (`dict`): Dictionary mapping features to actual values. Should be created using [`TapasTokenizer`]. - logits (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, sequence_length)`): + logits (`torch.Tensor` of shape `(batch_size, sequence_length)`): Tensor containing the logits at the token level. - logits_agg (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, num_aggregation_labels)`, *optional*): + logits_agg (`torch.Tensor` of shape `(batch_size, num_aggregation_labels)`, *optional*): Tensor containing the aggregation logits. cell_classification_threshold (`float`, *optional*, defaults to 0.5): Threshold to be used for cell selection. All table cells for which their probability is larger than @@ -1913,7 +1913,6 @@ def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_clas - predicted_aggregation_indices (`list[int]`of length `batch_size`, *optional*, returned when `logits_aggregation` is provided): Predicted aggregation operator indices of the aggregation head. """ - # converting to numpy arrays to work with PT/TF logits = logits.numpy() if logits_agg is not None: logits_agg = logits_agg.numpy() diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 0b8b626d2dd2..dcbd220331f9 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -557,8 +557,6 @@ class TvpPreTrainedModel(PreTrainedModel): def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 97c227f1d8bf..d249e65c5a45 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -414,8 +414,6 @@ class VideoMAEPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv3d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 09ee7f07357d..8535b3c747e2 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -546,8 +546,6 @@ class ViltPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index ecb1a6cd1594..f0277a7bd820 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -500,8 +500,6 @@ class VisualBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if hasattr(module, "bias") and module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index a74d172805bd..72c90af31f81 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -555,8 +555,6 @@ class ViTMAEPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index eee739b13864..d66d94fcf56e 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -395,8 +395,6 @@ class ViTMSNPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index b27b56e640c6..a18bcc49bf5c 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -400,8 +400,6 @@ class VivitPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv3d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 5e0d04acef1f..34d9a8965be8 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -577,7 +577,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): Compute offsets for a given tokenized input Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. @@ -656,7 +656,7 @@ def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False): Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be @@ -692,7 +692,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. Will remove the previous tokens (pre-prompt) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 235b356edab4..fbcf8ea757bd 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -210,7 +210,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): Compute offsets for a given tokenized input Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. @@ -291,7 +291,7 @@ def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False): Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be @@ -329,7 +329,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. Will remove the previous tokens (pre-prompt) diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 13fd9886ea96..7677dcae64a7 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -471,8 +471,6 @@ class YolosPreTrainedModel(PreTrainedModel): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 80c0dfd6539e..b1d8e5e752a1 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -644,8 +644,6 @@ def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 7bbad31c2ee0..f03804c2c57b 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -1213,8 +1213,6 @@ class ZoeDepthPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index db7a0d77da64..61c0aff4e029 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -897,7 +897,7 @@ def __init__( # each pipeline with text generation capabilities should define its own default generation in a # `_default_generation_config` class attribute default_pipeline_generation_config = getattr(self, "_default_generation_config", GenerationConfig()) - if hasattr(self.model, "_prepare_generation_config"): # TF doesn't have `_prepare_generation_config` + if hasattr(self.model, "_prepare_generation_config"): # Uses `generate`'s logic to enforce the following priority of arguments: # 1. user-defined config options in `**kwargs` # 2. model's generation config values diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 094c5861ab10..b8931c9988f6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3252,7 +3252,6 @@ def test_logits_processor_not_inplace(self): self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist()) def test_eos_token_id_int_and_list_top_k_top_sampling(self): - # Has TF equivalent: this test relies on random sampling generation_kwargs = { "do_sample": True, "num_beams": 1, @@ -3280,7 +3279,6 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self): self.assertTrue(expectation == len(generated_tokens[0])) def test_model_kwarg_encoder_signature_filtering(self): - # Has TF equivalent: ample use of framework-specific code bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") article = """Hugging Face is a technology company based in New York and Paris.""" input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 189773afd399..a59cf4fefffa 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -705,7 +705,7 @@ def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=Tru self.text_model_tester = BlipTextModelTester(parent, **text_kwargs) self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs) self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test - self.seq_length = self.text_model_tester.seq_length # need seq_length for pt-tf equivalence test + self.seq_length = self.text_model_tester.seq_length self.is_training = is_training def prepare_config_and_inputs(self): diff --git a/tests/models/rembert/test_modeling_rembert.py b/tests/models/rembert/test_modeling_rembert.py index e7804c59799b..93a16866601b 100644 --- a/tests/models/rembert/test_modeling_rembert.py +++ b/tests/models/rembert/test_modeling_rembert.py @@ -477,17 +477,6 @@ def test_inference_model(self): ] ) - # Running on the original tf implementation gives slightly different results here. - # Not clear why this variations is present - # TODO: Find reason for discrepancy - # expected_original_implementation = [[ - # [0.07630594074726105, -0.20146065950393677, 0.19107051193714142], - # [-0.3405614495277405, -0.36971670389175415, -0.4808273911476135], - # [-0.22587086260318756, -0.6656315922737122, -0.07844287157058716], - # [-0.04145475849509239, -0.3077218234539032, -0.42316967248916626], - # [-0.15887849032878876, -0.054529931396245956, 0.5356100797653198] - # ]] - torch.testing.assert_close( output["last_hidden_state"][:, :, :3], expected_implementation, rtol=1e-4, atol=1e-4 ) diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 9ffd1ca6a80f..65e5e4d2758a 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -576,9 +576,6 @@ def default_tokenizer(self): @slow def test_inference_no_head(self): - # ideally we want to test this with the weights of tapas_inter_masklm_base_reset, - # but since it's not straightforward to do this with the TF 1 implementation, we test it with - # the weights of the WTQ base model (i.e. tapas_wtq_wikisql_sqa_inter_masklm_base_reset) model = TapasModel.from_pretrained("google/tapas-base-finetuned-wtq").to(torch_device) tokenizer = self.default_tokenizer @@ -767,7 +764,6 @@ def test_training_question_answering_head_weak_supervision(self): # note that google/tapas-base-finetuned-wtq should correspond to tapas_wtq_wikisql_sqa_inter_masklm_base_reset model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq").to(torch_device) model.to(torch_device) - # normally we should put the model in training mode but it's a pain to do this with the TF 1 implementation tokenizer = self.default_tokenizer # let's test on a batch diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index a79bcec8af72..b28d4711d589 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -358,7 +358,6 @@ def default_model(self): @slow def test_inference_for_pretraining(self): - # make random mask reproducible across the PT and TF model np.random.seed(2) model = self.default_model @@ -367,8 +366,6 @@ def test_inference_for_pretraining(self): image = prepare_img() inputs = image_processor(images=image, return_tensors="pt").to(torch_device) - # prepare a noise vector that will be also used for testing the TF model - # (this way we can ensure that the PT and TF models operate on the same inputs) vit_mae_config = ViTMAEConfig() num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2) noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device) @@ -394,7 +391,6 @@ def test_inference_interpolate_pos_encoding(self): # the model on higher resolutions. The DINO model by Facebook AI leverages this # to visualize self-attention on higher resolution images. - # make random mask reproducible across the PT and TF model np.random.seed(2) model = self.default_model @@ -403,8 +399,6 @@ def test_inference_interpolate_pos_encoding(self): image = prepare_img() inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device) - # prepare a noise vector that will be also used for testing the TF model - # (this way we can ensure that the PT and TF models operate on the same inputs) vit_mae_config = ViTMAEConfig() num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size) noise = torch.from_numpy(np.random.uniform(size=(1, num_patches))).to(device=torch_device) @@ -421,7 +415,6 @@ def test_inference_interpolate_pos_encoding(self): def test_inference_interpolate_pos_encoding_custom_sizes(self): # Ensure custom sizes are correctly handled when interpolating the position embeddings - # make random mask reproducible across the PT and TF model np.random.seed(2) model = self.default_model diff --git a/tests/pipelines/test_pipelines_zero_shot_audio_classification.py b/tests/pipelines/test_pipelines_zero_shot_audio_classification.py index 0f884240bf12..acd64d7705fb 100644 --- a/tests/pipelines/test_pipelines_zero_shot_audio_classification.py +++ b/tests/pipelines/test_pipelines_zero_shot_audio_classification.py @@ -46,10 +46,6 @@ def test_small_model_pt(self, dtype="float32"): def test_small_model_pt_fp16(self): self.test_small_model_pt(dtype="float16") - @unittest.skip(reason="No models are available in TF") - def test_small_model_tf(self): - pass - @slow @require_torch def test_large_model_pt(self): @@ -94,7 +90,3 @@ def test_large_model_pt(self): ] * 5, ) - - @unittest.skip(reason="No models are available in TF") - def test_large_model_tf(self): - pass diff --git a/utils/notification_service.py b/utils/notification_service.py index ca2085c74b25..f5c1a7d2b132 100644 --- a/utils/notification_service.py +++ b/utils/notification_service.py @@ -445,7 +445,7 @@ def per_model_sum(model_category_dict): # (Possibly truncated) reports for the current workflow run - to be sent to Slack channels if job_name == "run_models_gpu": - model_header = "Single PT | Multi PT | Single TF | Multi TF | Other | Category\n" + model_header = "Single PT | Multi PT | Other | Category\n" else: model_header = "Single | Multi | Category\n" diff --git a/utils/update_tiny_models.py b/utils/update_tiny_models.py index d5cb048ea639..ee81407d4124 100644 --- a/utils/update_tiny_models.py +++ b/utils/update_tiny_models.py @@ -135,13 +135,6 @@ def get_tiny_model_summary_from_hub(output_path): content["model_classes"].add(m.__class__.__name__) except Exception: pass - try: - time.sleep(1) - model_class = getattr(transformers, f"TF{model}") - m = model_class.from_pretrained(repo_id) - content["model_classes"].add(m.__class__.__name__) - except Exception: - pass content["tokenizer_classes"] = sorted(content["tokenizer_classes"]) content["processor_classes"] = sorted(content["processor_classes"]) From 72b2a287d85741ab88e30abb6213134f5d33dd6a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 18:59:54 +0200 Subject: [PATCH 27/35] nearly the end --- .../models/pegasus_x/modeling_pegasus_x.py | 3 - src/transformers/onnx/convert.py | 1 - tests/models/fnet/test_modeling_fnet.py | 55 ------------------- .../pipelines/test_pipelines_text_to_audio.py | 2 +- tests/sagemaker/conftest.py | 7 +-- .../test_multi_node_data_parallel.py | 13 ++--- .../test_multi_node_model_parallel.py | 13 ++--- tests/sagemaker/test_single_node_gpu.py | 12 ++-- 8 files changed, 19 insertions(+), 87 deletions(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 0c1ae32cabe2..231d6601d28a 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -67,9 +67,6 @@ class DimensionInfo: global_len: int # global length padded_seq_len: int # padded token seq length - # Note: Compared to the original Flax implementation, we will pad the token representations to - # a multiple of block size at the start of the encoder layers, so T=P always. - # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index f76f08b23db5..bcf7fc878890 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -121,7 +121,6 @@ def export_pytorch( import torch from torch.onnx import export as onnx_export - logger.info(f"Using framework PyTorch: {torch.__version__}") with torch.no_grad(): model.config.return_dict = True model.eval() diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py index eca30656dd20..f2bc7b099437 100644 --- a/tests/models/fnet/test_modeling_fnet.py +++ b/tests/models/fnet/test_modeling_fnet.py @@ -439,61 +439,6 @@ def test_model_from_pretrained(self): class FNetModelIntegrationTest(unittest.TestCase): @slow def test_inference_for_masked_lm(self): - """ - For comparison: - 1. Modify the pre-training model `__call__` to skip computing metrics and return masked_lm_output like so: - ``` - ... - sequence_output, pooled_output = EncoderModel( - self.config, random_seed=self.random_seed, name="encoder")( - input_ids, input_mask, type_ids, deterministic=deterministic) - - masked_lm_output = nn.Dense( - self.config.d_emb, - kernel_init=default_kernel_init, - name="predictions_dense")( - sequence_output) - masked_lm_output = nn.gelu(masked_lm_output) - masked_lm_output = nn.LayerNorm( - epsilon=LAYER_NORM_EPSILON, name="predictions_layer_norm")( - masked_lm_output) - masked_lm_logits = layers.OutputProjection( - kernel=self._get_embedding_table(), name="predictions_output")( - masked_lm_output) - - next_sentence_logits = layers.OutputProjection( - n_out=2, kernel_init=default_kernel_init, name="classification")( - pooled_output) - - return masked_lm_logits - ... - ``` - 2. Run the following: - >>> import jax.numpy as jnp - >>> import sentencepiece as spm - >>> from flax.training import checkpoints - >>> from f_net.models import PreTrainingModel - >>> from f_net.configs.pretraining import get_config, ModelArchitecture - - >>> pretrained_params = checkpoints.restore_checkpoint('./f_net/f_net_checkpoint', None) # Location of original checkpoint - >>> pretrained_config = get_config() - >>> pretrained_config.model_arch = ModelArchitecture.F_NET - - >>> vocab_filepath = "./f_net/c4_bpe_sentencepiece.model" # Location of the sentence piece model - >>> tokenizer = spm.SentencePieceProcessor() - >>> tokenizer.Load(vocab_filepath) - >>> with pretrained_config.unlocked(): - >>> pretrained_config.vocab_size = tokenizer.GetPieceSize() - >>> tokens = jnp.array([[0, 1, 2, 3, 4, 5]]) - >>> type_ids = jnp.zeros_like(tokens, dtype="i4") - >>> attention_mask = jnp.ones_like(tokens) # Dummy. This gets deleted inside the model. - - >>> flax_pretraining_model = PreTrainingModel(pretrained_config) - >>> pretrained_model_params = freeze(pretrained_params['target']) - >>> flax_model_outputs = flax_pretraining_model.apply({"params": pretrained_model_params}, tokens, attention_mask, type_ids, None, None, None, None, deterministic=True) - >>> masked_lm_logits[:, :3, :3] - """ - model = FNetForMaskedLM.from_pretrained("google/fnet-base") model.to(torch_device) diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index fb56d0c64b54..c13d0830c6e6 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -247,7 +247,7 @@ def test_generative_model_kwargs(self): @slow @require_torch def test_csm_model_pt(self): - speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", framework="pt") + speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b") outputs = speech_generator("[0]This is a test") self.assertEqual(outputs["sampling_rate"], 24000) diff --git a/tests/sagemaker/conftest.py b/tests/sagemaker/conftest.py index 4dbfac3588ee..5daf3c4147f9 100644 --- a/tests/sagemaker/conftest.py +++ b/tests/sagemaker/conftest.py @@ -12,7 +12,6 @@ @dataclass class SageMakerTestEnvironment: - framework: str = "pytorch" role = "arn:aws:iam::558105141721:role/sagemaker_execution_role" hyperparameters = { "task_name": "mnli", @@ -38,11 +37,11 @@ def metric_definitions(self) -> str: @property def base_job_name(self) -> str: - return f"{self.framework}-transformers-test" + return "pytorch-transformers-test" @property def test_path(self) -> str: - return f"./tests/sagemaker/scripts/{self.framework}" + return "./tests/sagemaker/scripts/pytorch" @property def image_uri(self) -> str: @@ -51,4 +50,4 @@ def image_uri(self) -> str: @pytest.fixture(scope="class") def sm_env(request): - request.cls.env = SageMakerTestEnvironment(framework=request.cls.framework) + request.cls.env = SageMakerTestEnvironment() diff --git a/tests/sagemaker/test_multi_node_data_parallel.py b/tests/sagemaker/test_multi_node_data_parallel.py index 602a90e6d8e8..3f2fd7c04166 100644 --- a/tests/sagemaker/test_multi_node_data_parallel.py +++ b/tests/sagemaker/test_multi_node_data_parallel.py @@ -23,14 +23,12 @@ @parameterized_class( [ { - "framework": "pytorch", "script": "run_glue.py", "model_name_or_path": "distilbert/distilbert-base-cased", "instance_type": "ml.p3.16xlarge", "results": {"train_runtime": 650, "eval_accuracy": 0.7, "eval_loss": 0.6}, }, { - "framework": "pytorch", "script": "run_ddp.py", "model_name_or_path": "distilbert/distilbert-base-cased", "instance_type": "ml.p3.16xlarge", @@ -40,12 +38,11 @@ ) class MultiNodeTest(unittest.TestCase): def setUp(self): - if self.framework == "pytorch": - subprocess.run( - f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), - encoding="utf-8", - check=True, - ) + subprocess.run( + f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) assert hasattr(self, "env") def create_estimator(self, instance_count): diff --git a/tests/sagemaker/test_multi_node_model_parallel.py b/tests/sagemaker/test_multi_node_model_parallel.py index 216d31de4710..818028a75783 100644 --- a/tests/sagemaker/test_multi_node_model_parallel.py +++ b/tests/sagemaker/test_multi_node_model_parallel.py @@ -23,14 +23,12 @@ @parameterized_class( [ { - "framework": "pytorch", "script": "run_glue_model_parallelism.py", "model_name_or_path": "FacebookAI/roberta-large", "instance_type": "ml.p3dn.24xlarge", "results": {"train_runtime": 1600, "eval_accuracy": 0.3, "eval_loss": 1.2}, }, { - "framework": "pytorch", "script": "run_glue.py", "model_name_or_path": "FacebookAI/roberta-large", "instance_type": "ml.p3dn.24xlarge", @@ -40,12 +38,11 @@ ) class MultiNodeTest(unittest.TestCase): def setUp(self): - if self.framework == "pytorch": - subprocess.run( - f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), - encoding="utf-8", - check=True, - ) + subprocess.run( + f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) assert hasattr(self, "env") def create_estimator(self, instance_count): diff --git a/tests/sagemaker/test_single_node_gpu.py b/tests/sagemaker/test_single_node_gpu.py index c1902797391d..7fc764810eb0 100644 --- a/tests/sagemaker/test_single_node_gpu.py +++ b/tests/sagemaker/test_single_node_gpu.py @@ -23,7 +23,6 @@ @parameterized_class( [ { - "framework": "pytorch", "script": "run_glue.py", "model_name_or_path": "distilbert/distilbert-base-cased", "instance_type": "ml.g4dn.xlarge", @@ -33,12 +32,11 @@ ) class SingleNodeTest(unittest.TestCase): def setUp(self): - if self.framework == "pytorch": - subprocess.run( - f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), - encoding="utf-8", - check=True, - ) + subprocess.run( + f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) assert hasattr(self, "env") def create_estimator(self, instance_count=1): From 7c7d17667324c23c17fd21015617595945629e16 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 11 Sep 2025 19:09:17 +0200 Subject: [PATCH 28/35] could it really be the end? --- .../layoutlmv2/test_tokenization_layoutlmv2.py | 2 +- .../layoutlmv3/test_tokenization_layoutlmv3.py | 2 +- .../layoutxlm/test_tokenization_layoutxlm.py | 2 +- .../models/markuplm/test_tokenization_markuplm.py | 2 +- tests/models/tapas/test_tokenization_tapas.py | 2 +- tests/models/udop/test_tokenization_udop.py | 2 +- tests/pipelines/test_pipelines_common.py | 14 +++++++------- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index 5a2a008db02f..b3e7adc68257 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -2398,7 +2398,7 @@ def test_layoutlmv2_integration_test(self): self.assertDictEqual(dict(encoding_p), expected_results) self.assertDictEqual(dict(encoding_r), expected_results) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index 1568a7e01104..729a7f4034f7 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -2322,7 +2322,7 @@ def test_layoutlmv3_integration_test(self): self.assertDictEqual(dict(encoding_p), expected_results) self.assertDictEqual(dict(encoding_r), expected_results) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index 5c09a14ecaa6..da77177a3f62 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -1884,7 +1884,7 @@ def test_layoutxlm_integration_test(self): self.assertDictEqual(dict(encoding_p), expected_results) self.assertDictEqual(dict(encoding_r), expected_results) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index 0d5ef0efdb02..8232269c53d4 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -2195,7 +2195,7 @@ def test_markuplm_integration_test(self): self.assertDictEqual(dict(encoding_p), expected_results) self.assertDictEqual(dict(encoding_r), expected_results) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index 1f95919e27db..6dc57b94b064 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -1150,7 +1150,7 @@ def test_full_tokenizer(self): self.assertListEqual(column_ids.tolist(), expected_results["column_ids"]) self.assertListEqual(row_ids.tolist(), expected_results["row_ids"]) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py index f5270ee84d2c..a7d301607e84 100644 --- a/tests/models/udop/test_tokenization_udop.py +++ b/tests/models/udop/test_tokenization_udop.py @@ -1773,7 +1773,7 @@ def test_udop_integration_test(self): self.assertDictEqual(dict(encoding_p), expected_results) self.assertDictEqual(dict(encoding_r), expected_results) - @unittest.skip(reason="Doesn't support another framework than PyTorch") + @unittest.skip(reason="Doesn't support returning Numpy arrays") def test_np_encode_plus_sent_to_model(self): pass diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d94f09987b20..aa3b1b55beba 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -564,7 +564,7 @@ def test_load_default_pipelines_pt(self): # test table in separate test due to more dependencies continue - self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt) + self.check_default_pipeline(task, set_seed_fn, self.check_models_equal_pt) # clean-up as much as possible GPU memory occupied by PyTorch gc.collect() @@ -576,7 +576,7 @@ def test_load_default_pipelines_pt_table_qa(self): import torch set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731 - self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt) + self.check_default_pipeline("table-question-answering", set_seed_fn, self.check_models_equal_pt) # clean-up as much as possible GPU memory occupied by PyTorch gc.collect() @@ -624,17 +624,17 @@ def test_bc_torch_device(self): self.assertEqual(k1, k2) self.assertEqual(v1.dtype, v2.dtype) - def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn): + def check_default_pipeline(self, task, set_seed_fn, check_models_equal_fn): from transformers.pipelines import SUPPORTED_TASKS, pipeline task_dict = SUPPORTED_TASKS[task] # test to compare pipeline to manually loading the respective model model = None - relevant_auto_classes = task_dict[framework] + relevant_auto_classes = task_dict["pt"] if len(relevant_auto_classes) == 0: # task has no default - self.skipTest(f"{task} in {framework} has no default") + self.skipTest(f"{task} in pytorch has no default") # by default use first class auto_model_cls = relevant_auto_classes[0] @@ -646,14 +646,14 @@ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equa revisions = [] tasks = [] for translation_pair in task_dict["default"]: - model_id, revision = task_dict["default"][translation_pair]["model"][framework] + model_id, revision = task_dict["default"][translation_pair]["model"] model_ids.append(model_id) revisions.append(revision) tasks.append(task + f"_{'_to_'.join(translation_pair)}") else: # normal case - non-translation pipeline - model_id, revision = task_dict["default"]["model"][framework] + model_id, revision = task_dict["default"]["model"] model_ids = [model_id] revisions = [revision] From 9a41b8d5d34e3640fdd49710b710f1f45cd667cd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 12 Sep 2025 10:02:11 +0200 Subject: [PATCH 29/35] small fix --- tests/utils/test_configuration_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index 60e8703937cd..50b5b11db1b7 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -40,8 +40,6 @@ "output_attentions": True, "torchscript": True, "dtype": "float16", - "use_bfloat16": True, - "tf_legacy_loss": True, "pruned_heads": {"a": 1}, "tie_word_embeddings": False, "is_decoder": True, From a75f707adcff151734455b16ec22a29bc666fb0e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 15:38:48 +0200 Subject: [PATCH 30/35] add converters back --- ...lbert_original_tf_checkpoint_to_pytorch.py | 62 +++ .../models/align/convert_align_tf_to_hf.py | 389 ++++++++++++++++++ ..._bert_original_tf_checkpoint_to_pytorch.py | 62 +++ 3 files changed, 513 insertions(+) create mode 100644 src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/align/convert_align_tf_to_hf.py create mode 100644 src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py diff --git a/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..df2a22610187 --- /dev/null +++ b/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ALBERT checkpoint.""" + +import argparse + +import torch + +from ...utils import logging +from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = AlbertConfig.from_json_file(albert_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = AlbertForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_albert(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--albert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained ALBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/align/convert_align_tf_to_hf.py b/src/transformers/models/align/convert_align_tf_to_hf.py new file mode 100644 index 000000000000..74309a0d7076 --- /dev/null +++ b/src/transformers/models/align/convert_align_tf_to_hf.py @@ -0,0 +1,389 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ALIGN checkpoints from the original repository.""" + +import argparse +import os + +import align +import numpy as np +import requests +import tensorflow as tf +import torch +from PIL import Image +from tokenizer import Tokenizer + +from transformers import ( + AlignConfig, + AlignModel, + AlignProcessor, + BertConfig, + BertTokenizer, + EfficientNetConfig, + EfficientNetImageProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def preprocess(image): + image = tf.image.resize(image, (346, 346)) + image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289) + return image + + +def get_align_config(): + vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7") + vision_config.image_size = 289 + vision_config.hidden_dim = 640 + vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"} + vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1} + vision_config.depthwise_padding = [] + + text_config = BertConfig() + config = AlignConfig.from_text_vision_configs( + text_config=text_config, vision_config=vision_config, projection_dim=640 + ) + return config + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +def get_processor(): + image_processor = EfficientNetImageProcessor( + do_center_crop=True, + rescale_factor=1 / 127.5, + rescale_offset=True, + do_normalize=False, + include_top=False, + resample=Image.BILINEAR, + ) + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + tokenizer.model_max_length = 64 + processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer) + return processor + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def rename_keys(original_param_names): + # EfficientNet image encoder + block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")] + block_names = list(set(block_names)) + block_names = sorted(block_names) + num_blocks = len(block_names) + block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))} + + rename_keys = [] + rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight")) + rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight")) + rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias")) + rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean")) + rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var")) + + for b in block_names: + hf_b = block_name_mapping[b] + rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight")) + rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight")) + rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias")) + rename_keys.append( + (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean") + ) + rename_keys.append( + (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var") + ) + rename_keys.append( + (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight") + ) + rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight")) + rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias")) + rename_keys.append( + (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean") + ) + rename_keys.append( + (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var") + ) + + rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight")) + rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias")) + rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight")) + rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias")) + rename_keys.append( + (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight") + ) + rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight")) + rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias")) + rename_keys.append( + (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean") + ) + rename_keys.append( + (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var") + ) + + key_mapping = {} + for item in rename_keys: + if item[0] in original_param_names: + key_mapping[item[0]] = "vision_model." + item[1] + + # BERT text encoder + rename_keys = [] + old = "tf_bert_model/bert" + new = "text_model" + for i in range(12): + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0", + f"{new}.encoder.layer.{i}.attention.self.query.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/query/bias:0", + f"{new}.encoder.layer.{i}.attention.self.query.bias", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0", + f"{new}.encoder.layer.{i}.attention.self.key.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/key/bias:0", + f"{new}.encoder.layer.{i}.attention.self.key.bias", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0", + f"{new}.encoder.layer.{i}.attention.self.value.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/self/value/bias:0", + f"{new}.encoder.layer.{i}.attention.self.value.bias", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0", + f"{new}.encoder.layer.{i}.attention.output.dense.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0", + f"{new}.encoder.layer.{i}.attention.output.dense.bias", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0", + f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0", + f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0", + f"{new}.encoder.layer.{i}.intermediate.dense.weight", + ) + ) + rename_keys.append( + ( + f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0", + f"{new}.encoder.layer.{i}.intermediate.dense.bias", + ) + ) + rename_keys.append( + (f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight") + ) + rename_keys.append( + (f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias") + ) + rename_keys.append( + (f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight") + ) + rename_keys.append( + (f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias") + ) + + rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight")) + rename_keys.append( + (f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight") + ) + rename_keys.append( + (f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight") + ) + rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight")) + rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias")) + + rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight")) + rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias")) + rename_keys.append(("dense/kernel:0", "text_projection.weight")) + rename_keys.append(("dense/bias:0", "text_projection.bias")) + rename_keys.append(("dense/bias:0", "text_projection.bias")) + rename_keys.append(("temperature:0", "temperature")) + + for item in rename_keys: + if item[0] in original_param_names: + key_mapping[item[0]] = item[1] + return key_mapping + + +def replace_params(hf_params, tf_params, key_mapping): + list(hf_params.keys()) + + for key, value in tf_params.items(): + if key not in key_mapping: + continue + + hf_key = key_mapping[key] + if "_conv" in key and "kernel" in key: + new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1) + elif "embeddings" in key: + new_hf_value = torch.from_numpy(value) + elif "depthwise_kernel" in key: + new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1) + elif "kernel" in key: + new_hf_value = torch.from_numpy(np.transpose(value)) + elif "temperature" in key: + new_hf_value = value + elif "bn/gamma" in key or "bn/beta" in key: + new_hf_value = torch.from_numpy(np.transpose(value)).squeeze() + else: + new_hf_value = torch.from_numpy(value) + + # Replace HF parameters with original TF model parameters + hf_params[hf_key].copy_(new_hf_value) + + +@torch.no_grad() +def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub): + """ + Copy/paste/tweak model's weights to our ALIGN structure. + """ + # Load original model + seq_length = 64 + tok = Tokenizer(seq_length) + original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size()) + original_model.compile() + original_model.load_weights(checkpoint_path) + + tf_params = original_model.trainable_variables + tf_non_train_params = original_model.non_trainable_variables + tf_params = {param.name: param.numpy() for param in tf_params} + for param in tf_non_train_params: + tf_params[param.name] = param.numpy() + tf_param_names = list(tf_params.keys()) + + # Load HuggingFace model + config = get_align_config() + hf_model = AlignModel(config).eval() + hf_params = hf_model.state_dict() + + # Create src-to-dst parameter name mapping dictionary + print("Converting parameters...") + key_mapping = rename_keys(tf_param_names) + replace_params(hf_params, tf_params, key_mapping) + + # Initialize processor + processor = get_processor() + inputs = processor( + images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt" + ) + + # HF model inference + hf_model.eval() + with torch.no_grad(): + outputs = hf_model(**inputs) + + hf_image_features = outputs.image_embeds.detach().numpy() + hf_text_features = outputs.text_embeds.detach().numpy() + + # Original model inference + original_model.trainable = False + tf_image_processor = EfficientNetImageProcessor( + do_center_crop=True, + do_rescale=False, + do_normalize=False, + include_top=False, + resample=Image.BILINEAR, + ) + image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"] + text = tok(tf.constant(["A picture of a cat"])) + + image_features = original_model.image_encoder(image, training=False) + text_features = original_model.text_encoder(text, training=False) + + image_features = tf.nn.l2_normalize(image_features, axis=-1) + text_features = tf.nn.l2_normalize(text_features, axis=-1) + + # Check whether original and HF model outputs match -> np.allclose + if not np.allclose(image_features, hf_image_features, atol=1e-3): + raise ValueError("The predicted image features are not the same.") + if not np.allclose(text_features, hf_text_features, atol=1e-3): + raise ValueError("The predicted text features are not the same.") + print("Model outputs match!") + + if save_model: + # Create folder to save model + if not os.path.isdir(pytorch_dump_folder_path): + os.mkdir(pytorch_dump_folder_path) + # Save converted model and image processor + hf_model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Push model and image processor to hub + print("Pushing converted ALIGN to the hub...") + processor.push_to_hub("align-base") + hf_model.push_to_hub("align-base") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_path", + default="./weights/model-weights", + type=str, + help="Path to the pretrained TF ALIGN checkpoint.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="hf_model", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub") + + args = parser.parse_args() + convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub) diff --git a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..be904ddd7e6c --- /dev/null +++ b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BERT checkpoint.""" + +import argparse + +import torch + +from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = BertConfig.from_json_file(bert_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = BertForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_bert(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--bert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) From 0b008577f01357af92e0a402c3dfddd3e5e340a9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 15:56:44 +0200 Subject: [PATCH 31/35] post rebase --- src/transformers/models/janus/modular_janus.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index dcd5c1e1e730..ef99aaf45680 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1546,10 +1546,8 @@ def preprocess( return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. @@ -1579,10 +1577,7 @@ def preprocess( images = make_flat_list_of_images(images) if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") validate_preprocess_arguments( do_rescale=do_rescale, From d6692aa09ddb1630e27be80b042064292b52fa18 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 16:05:48 +0200 Subject: [PATCH 32/35] latest qwen --- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 2 -- src/transformers/models/qwen3_vl/processing_qwen3_vl.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 7a2fa852739e..0a97489e285f 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -1337,10 +1337,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index cac82e738f39..7535d28a4ad0 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -137,10 +137,8 @@ def __call__( tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: From ae33acfa33bfbdc1397007f78a3505de67459a3a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 17:03:06 +0200 Subject: [PATCH 33/35] add back all converters --- ...bert_original_tf2_checkpoint_to_pytorch.py | 246 +++++++++++ ..._bert_original_tf_checkpoint_to_pytorch.py | 0 ...ping_original_tf2_checkpoint_to_pytorch.py | 188 ++++++++ ...gbird_original_tf_checkpoint_to_pytorch.py | 69 +++ .../convert_bigbird_pegasus_tf_to_pytorch.py | 169 ++++++++ ..._byt5_original_tf_checkpoint_to_pytorch.py | 59 +++ ...anine_original_tf_checkpoint_to_pytorch.py | 65 +++ ...ginal_tf1_checkpoint_to_pytorch_and_tf2.py | 57 +++ ...convert_gptsan_tf_checkpoint_to_pytorch.py | 181 ++++++++ ...fo_xl_original_tf_checkpoint_to_pytorch.py | 121 ++++++ ...ectra_original_tf_checkpoint_to_pytorch.py | 79 ++++ ...net_original_flax_checkpoint_to_pytorch.py | 156 +++++++ ...unnel_original_tf_checkpoint_to_pytorch.py | 64 +++ ..._gpt2_original_tf_checkpoint_to_pytorch.py | 68 +++ .../convert_gpt_neo_mesh_tf_to_pytorch.py | 71 +++ ...onvert_imagegpt_original_tf2_to_pytorch.py | 71 +++ ...xmert_original_tf_checkpoint_to_pytorch.py | 59 +++ ...ebert_original_tf_checkpoint_to_pytorch.py | 58 +++ ...nvert_original_tf_checkpoint_to_pytorch.py | 141 ++++++ ...nvert_original_tf_checkpoint_to_pytorch.py | 177 ++++++++ ..._myt5_original_tf_checkpoint_to_pytorch.py | 60 +++ ...penai_original_tf_checkpoint_to_pytorch.py | 74 ++++ .../convert_owlvit_original_flax_to_hf.py | 406 ++++++++++++++++++ .../pegasus/convert_pegasus_tf_to_pytorch.py | 130 ++++++ ...onvert_rembert_tf_checkpoint_to_pytorch.py | 62 +++ ...ormer_original_tf_checkpoint_to_pytorch.py | 62 +++ .../convert_s2t_fairseq_to_tfms.py | 121 ++++++ ...ers_original_flax_checkpoint_to_pytorch.py | 203 +++++++++ ...rt_t5_original_tf_checkpoint_to_pytorch.py | 59 +++ ...tapas_original_tf_checkpoint_to_pytorch.py | 137 ++++++ .../vivit/convert_vivit_flax_to_pytorch.py | 231 ++++++++++ ...xlnet_original_tf_checkpoint_to_pytorch.py | 113 +++++ 32 files changed, 3757 insertions(+) create mode 100644 src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py mode change 100644 => 100755 src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py create mode 100755 src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py create mode 100644 src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py create mode 100755 src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py create mode 100755 src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py create mode 100644 src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py create mode 100755 src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py create mode 100755 src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py create mode 100644 src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py create mode 100755 src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py create mode 100755 src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py create mode 100644 src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py create mode 100755 src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py create mode 100755 src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py diff --git a/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..9dfd8da474e3 --- /dev/null +++ b/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py @@ -0,0 +1,246 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now +deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert + +TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert +weight names to the original names, so the model can be imported with Huggingface/transformer. + +You may adapt this script to include classification/MLM/NSP/etc. heads. + +Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0). + Models trained with never versions are not compatible with this script. +""" + +import argparse +import os +import re + +import tensorflow as tf +import torch + +from transformers import BertConfig, BertModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf2_weights_in_bert(model, tf_checkpoint_path, config): + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + layer_depth = [] + for full_name, shape in init_vars: + # logger.info(f"Loading TF weight {name} with shape {shape}") + name = full_name.split("/") + if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]: + logger.info(f"Skipping non-model layer {full_name}") + continue + if "optimizer" in full_name: + logger.info(f"Skipping optimization layer {full_name}") + continue + if name[0] == "model": + # ignore initial 'model' + name = name[1:] + # figure out how many levels deep the name is + depth = 0 + for _name in name: + if _name.startswith("layer_with_weights"): + depth += 1 + else: + break + layer_depth.append(depth) + # read data + array = tf.train.load_variable(tf_path, full_name) + names.append("/".join(name)) + arrays.append(array) + logger.info(f"Read a total of {len(arrays):,} layers") + + # Sanity check + if len(set(layer_depth)) != 1: + raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})") + layer_depth = list(set(layer_depth))[0] + if layer_depth != 1: + raise ValueError( + "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP" + " heads." + ) + + # convert layers + logger.info("Converting weights...") + for full_name, array in zip(names, arrays): + name = full_name.split("/") + pointer = model + trace = [] + for i, m_name in enumerate(name): + if m_name == ".ATTRIBUTES": + # variable names end with .ATTRIBUTES/VARIABLE_VALUE + break + if m_name.startswith("layer_with_weights"): + layer_num = int(m_name.split("-")[-1]) + if layer_num <= 2: + # embedding layers + # layer_num 0: word_embeddings + # layer_num 1: position_embeddings + # layer_num 2: token_type_embeddings + continue + elif layer_num == 3: + # embedding LayerNorm + trace.extend(["embeddings", "LayerNorm"]) + pointer = getattr(pointer, "embeddings") + pointer = getattr(pointer, "LayerNorm") + elif layer_num > 3 and layer_num < config.num_hidden_layers + 4: + # encoder layers + trace.extend(["encoder", "layer", str(layer_num - 4)]) + pointer = getattr(pointer, "encoder") + pointer = getattr(pointer, "layer") + pointer = pointer[layer_num - 4] + elif layer_num == config.num_hidden_layers + 4: + # pooler layer + trace.extend(["pooler", "dense"]) + pointer = getattr(pointer, "pooler") + pointer = getattr(pointer, "dense") + elif m_name == "embeddings": + trace.append("embeddings") + pointer = getattr(pointer, "embeddings") + if layer_num == 0: + trace.append("word_embeddings") + pointer = getattr(pointer, "word_embeddings") + elif layer_num == 1: + trace.append("position_embeddings") + pointer = getattr(pointer, "position_embeddings") + elif layer_num == 2: + trace.append("token_type_embeddings") + pointer = getattr(pointer, "token_type_embeddings") + else: + raise ValueError(f"Unknown embedding layer with name {full_name}") + trace.append("weight") + pointer = getattr(pointer, "weight") + elif m_name == "_attention_layer": + # self-attention layer + trace.extend(["attention", "self"]) + pointer = getattr(pointer, "attention") + pointer = getattr(pointer, "self") + elif m_name == "_attention_layer_norm": + # output attention norm + trace.extend(["attention", "output", "LayerNorm"]) + pointer = getattr(pointer, "attention") + pointer = getattr(pointer, "output") + pointer = getattr(pointer, "LayerNorm") + elif m_name == "_attention_output_dense": + # output attention dense + trace.extend(["attention", "output", "dense"]) + pointer = getattr(pointer, "attention") + pointer = getattr(pointer, "output") + pointer = getattr(pointer, "dense") + elif m_name == "_output_dense": + # output dense + trace.extend(["output", "dense"]) + pointer = getattr(pointer, "output") + pointer = getattr(pointer, "dense") + elif m_name == "_output_layer_norm": + # output dense + trace.extend(["output", "LayerNorm"]) + pointer = getattr(pointer, "output") + pointer = getattr(pointer, "LayerNorm") + elif m_name == "_key_dense": + # attention key + trace.append("key") + pointer = getattr(pointer, "key") + elif m_name == "_query_dense": + # attention query + trace.append("query") + pointer = getattr(pointer, "query") + elif m_name == "_value_dense": + # attention value + trace.append("value") + pointer = getattr(pointer, "value") + elif m_name == "_intermediate_dense": + # attention intermediate dense + trace.extend(["intermediate", "dense"]) + pointer = getattr(pointer, "intermediate") + pointer = getattr(pointer, "dense") + elif m_name == "_output_layer_norm": + # output layer norm + trace.append("output") + pointer = getattr(pointer, "output") + # weights & biases + elif m_name in ["bias", "beta"]: + trace.append("bias") + pointer = getattr(pointer, "bias") + elif m_name in ["kernel", "gamma"]: + trace.append("weight") + pointer = getattr(pointer, "weight") + else: + logger.warning(f"Ignored {m_name}") + # for certain layers reshape is necessary + trace = ".".join(trace) + if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match( + r"(\S+)\.attention\.output\.dense\.weight", trace + ): + array = array.reshape(pointer.data.shape) + if "kernel" in full_name: + array = array.transpose() + if pointer.shape == array.shape: + pointer.data = torch.from_numpy(array) + else: + raise ValueError( + f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:" + f" {array.shape}" + ) + logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}") + return model + + +def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path): + # Instantiate model + logger.info(f"Loading model based on config from {config_path}...") + config = BertConfig.from_json_file(config_path) + model = BertModel(config) + + # Load weights from checkpoint + logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...") + load_tf2_weights_in_bert(model, tf_checkpoint_path, config) + + # Save pytorch-model + logger.info(f"Saving PyTorch model to {pytorch_dump_path}...") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path." + ) + parser.add_argument( + "--bert_config_file", + type=str, + required=True, + help="The config json file corresponding to the BERT model. This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", + type=str, + required=True, + help="Path to the output PyTorch model (must include filename).", + ) + args = parser.parse_args() + convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py old mode 100644 new mode 100755 diff --git a/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..a7832a53d55d --- /dev/null +++ b/src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py @@ -0,0 +1,188 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT +model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository: + +https://github.com/tensorflow/models/tree/master/official/projects/token_dropping +""" + +import argparse + +import tensorflow as tf +import torch + +from transformers import BertConfig, BertForMaskedLM +from transformers.models.bert.modeling_bert import ( + BertIntermediate, + BertLayer, + BertOutput, + BertPooler, + BertSelfAttention, + BertSelfOutput, +) +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str): + def get_masked_lm_array(name: str): + full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE" + array = tf.train.load_variable(tf_checkpoint_path, full_name) + + if "kernel" in name: + array = array.transpose() + + return torch.from_numpy(array) + + def get_encoder_array(name: str): + full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE" + array = tf.train.load_variable(tf_checkpoint_path, full_name) + + if "kernel" in name: + array = array.transpose() + + return torch.from_numpy(array) + + def get_encoder_layer_array(layer_index: int, name: str): + full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE" + array = tf.train.load_variable(tf_checkpoint_path, full_name) + + if "kernel" in name: + array = array.transpose() + + return torch.from_numpy(array) + + def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape): + full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE" + array = tf.train.load_variable(tf_checkpoint_path, full_name) + array = array.reshape(original_shape) + + if "kernel" in name: + array = array.transpose() + + return torch.from_numpy(array) + + print(f"Loading model based on config from {config_path}...") + config = BertConfig.from_json_file(config_path) + model = BertForMaskedLM(config) + + # Layers + for layer_index in range(0, config.num_hidden_layers): + layer: BertLayer = model.bert.encoder.layer[layer_index] + + # Self-attention + self_attn: BertSelfAttention = layer.attention.self + + self_attn.query.weight.data = get_encoder_attention_layer_array( + layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape + ) + self_attn.query.bias.data = get_encoder_attention_layer_array( + layer_index, "_query_dense/bias", self_attn.query.bias.data.shape + ) + self_attn.key.weight.data = get_encoder_attention_layer_array( + layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape + ) + self_attn.key.bias.data = get_encoder_attention_layer_array( + layer_index, "_key_dense/bias", self_attn.key.bias.data.shape + ) + self_attn.value.weight.data = get_encoder_attention_layer_array( + layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape + ) + self_attn.value.bias.data = get_encoder_attention_layer_array( + layer_index, "_value_dense/bias", self_attn.value.bias.data.shape + ) + + # Self-attention Output + self_output: BertSelfOutput = layer.attention.output + + self_output.dense.weight.data = get_encoder_attention_layer_array( + layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape + ) + self_output.dense.bias.data = get_encoder_attention_layer_array( + layer_index, "_output_dense/bias", self_output.dense.bias.data.shape + ) + + self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma") + self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta") + + # Intermediate + intermediate: BertIntermediate = layer.intermediate + + intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel") + intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias") + + # Output + bert_output: BertOutput = layer.output + + bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel") + bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias") + + bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma") + bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta") + + # Embeddings + model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings") + model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings") + model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma") + model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta") + + # LM Head + lm_head = model.cls.predictions.transform + + lm_head.dense.weight.data = get_masked_lm_array("dense/kernel") + lm_head.dense.bias.data = get_masked_lm_array("dense/bias") + + lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma") + lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta") + + model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table") + + # Pooling + model.bert.pooler = BertPooler(config=config) + model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel") + model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias") + + # Export final model + model.save_pretrained(pytorch_dump_path) + + # Integration test - should load without any errors ;) + new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path) + print(new_model.eval()) + + print("Model conversion was done successfully!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path." + ) + parser.add_argument( + "--bert_config_file", + type=str, + required=True, + help="The config json file corresponding to the BERT model. This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", + type=str, + required=True, + help="Path to the output PyTorch model.", + ) + args = parser.parse_args() + convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..0b8e6590f937 --- /dev/null +++ b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BigBird checkpoint.""" + +import argparse + +from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): + # Initialise PyTorch model + config = BigBirdConfig.from_json_file(big_bird_config_file) + print(f"Building PyTorch model from configuration: {config}") + + if is_trivia_qa: + model = BigBirdForQuestionAnswering(config) + else: + model = BigBirdForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--big_bird_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa + ) diff --git a/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py b/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py new file mode 100644 index 000000000000..d0a312ebc11f --- /dev/null +++ b/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import tensorflow as tf +import torch +from tqdm import tqdm + +from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration + + +INIT_COMMON = [ + # tf -> hf + ("/", "."), + ("layer_", "layers."), + ("kernel", "weight"), + ("beta", "bias"), + ("gamma", "weight"), + ("pegasus", "model"), +] +END_COMMON = [ + (".output.dense", ".fc2"), + ("intermediate.LayerNorm", "final_layer_norm"), + ("intermediate.dense", "fc1"), +] + +DECODER_PATTERNS = ( + INIT_COMMON + + [ + ("attention.self.LayerNorm", "self_attn_layer_norm"), + ("attention.output.dense", "self_attn.out_proj"), + ("attention.self", "self_attn"), + ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"), + ("attention.encdec_output.dense", "encoder_attn.out_proj"), + ("attention.encdec", "encoder_attn"), + ("key", "k_proj"), + ("value", "v_proj"), + ("query", "q_proj"), + ("decoder.LayerNorm", "decoder.layernorm_embedding"), + ] + + END_COMMON +) + +REMAINING_PATTERNS = ( + INIT_COMMON + + [ + ("embeddings.word_embeddings", "shared.weight"), + ("embeddings.position_embeddings", "embed_positions.weight"), + ("attention.self.LayerNorm", "self_attn_layer_norm"), + ("attention.output.dense", "self_attn.output"), + ("attention.self", "self_attn.self"), + ("encoder.LayerNorm", "encoder.layernorm_embedding"), + ] + + END_COMMON +) + +KEYS_TO_IGNORE = [ + "encdec/key/bias", + "encdec/query/bias", + "encdec/value/bias", + "self/key/bias", + "self/query/bias", + "self/value/bias", + "encdec_output/dense/bias", + "attention/output/dense/bias", +] + + +def rename_state_dict_key(k, patterns): + for tf_name, hf_name in patterns: + k = k.replace(tf_name, hf_name) + return k + + +def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration: + cfg = BigBirdPegasusConfig(**config_update) + torch_model = BigBirdPegasusForConditionalGeneration(cfg) + state_dict = torch_model.state_dict() + mapping = {} + + # separating decoder weights + decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")} + remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")} + + for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"): + conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] + if any(conditions): + continue + patterns = DECODER_PATTERNS + new_k = rename_state_dict_key(k, patterns) + if new_k not in state_dict: + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + if any(i in k for i in ["dense", "query", "key", "value"]): + v = v.T + mapping[new_k] = torch.from_numpy(v) + assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" + + for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"): + conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] + if any(conditions): + continue + patterns = REMAINING_PATTERNS + new_k = rename_state_dict_key(k, patterns) + if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings": + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + if any(i in k for i in ["dense", "query", "key", "value"]): + v = v.T + mapping[new_k] = torch.from_numpy(v) + if k != "pegasus/embeddings/position_embeddings": + assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" + + mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"] + mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight") + missing, extra = torch_model.load_state_dict(mapping, strict=False) + unexpected_missing = [ + k + for k in missing + if k + not in [ + "final_logits_bias", + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + "lm_head.weight", + ] + ] + assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" + assert extra == [], f"no matches found for the following tf keys {extra}" + return torch_model + + +def get_tf_weights_as_numpy(path) -> dict: + init_vars = tf.train.list_variables(path) + tf_weights = {} + ignore_name = ["global_step"] + for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): + skip_key = any(pat in name for pat in ignore_name) + if skip_key: + continue + array = tf.train.load_variable(path, name) + tf_weights[name] = array + return tf_weights + + +def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict): + tf_weights = get_tf_weights_as_numpy(ckpt_path) + torch_model = convert_bigbird_pegasus(tf_weights, config_update) + torch_model.save_pretrained(save_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables") + parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + config_update = {} + convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update) diff --git a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..9b1b15857cea --- /dev/null +++ b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2018 The T5 authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert T5 checkpoint.""" + +import argparse + +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..45dcdb290333 --- /dev/null +++ b/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert CANINE checkpoint.""" + +import argparse + +from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path): + # Initialize PyTorch model + config = CanineConfig() + model = CanineModel(config) + model.eval() + + print(f"Building PyTorch model from configuration: {config}") + + # Load weights from tf checkpoint + load_tf_weights_in_canine(model, config, tf_checkpoint_path) + + # Save pytorch-model (weights and configuration) + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Save tokenizer files + tokenizer = CanineTokenizer() + print(f"Save tokenizer files to {pytorch_dump_path}") + tokenizer.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint. Should end with model.ckpt", + ) + parser.add_argument( + "--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to a folder where the PyTorch model will be placed.", + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path) diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py new file mode 100644 index 000000000000..3d4ff779874b --- /dev/null +++ b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ConvBERT checkpoint.""" + +import argparse + +from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): + conf = ConvBertConfig.from_json_file(convbert_config_file) + model = ConvBertModel(conf) + + model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) + model.save_pretrained(pytorch_dump_path) + + tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) + tf_model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--convbert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained ConvBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py b/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..76b9c9cf328c --- /dev/null +++ b/src/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model.""" + +import argparse +import json +import os +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +import torch + + +def convert_tf_gptsan_to_pt(args): + parameter_file = os.path.join(args.tf_model_dir, "parameters.json") + params = json.loads(open(parameter_file).read()) + if not params: + raise ValueError( + f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file." + ) + if not args.output.endswith(".pt"): + args.output = args.output + ".pt" + new_state = OrderedDict() + with tf.device("/CPU:0"): + reader = tf.train.load_checkpoint(args.tf_model_dir) + shapes = reader.get_variable_to_shape_map() + for key_name in shapes: + vnp = reader.get_tensor(key_name).astype(np.float16) + if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"): + continue + if key_name.startswith("pasts/"): + if key_name.startswith("pasts/mlp"): + player = int(key_name[9]) + elif key_name.startswith("pasts/out"): + player = 8 + name = "model.sqout.%d.weight" % (player * 2) # enter to nn.Sequential with Tanh, so 2 at a time + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/moe"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/switch_gating/kernel"): + name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/softmlp/kernel"): + name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"): + nlayer = key_name[-9:-7] + for i in range(16): + name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer) + state = ( + vnp[i].transpose([1, 0]).copy() + ) # In Mesh-Tensorflow, it is one array, so it is divided + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/mlp"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/p1/kernel"): + name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p1/bias"): + name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p2/kernel"): + name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p2/bias"): + name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/ln"): + player = int(key_name[8:].split("/")[0]) + if key_name.endswith("/b"): + name = "model.blocks.%d.feed_forward.norm.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/g"): + name = "model.blocks.%d.feed_forward.norm.weight" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/att"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/qkv/kernel"): + state = vnp.copy() # Compute same dimension as Mesh-tensorflow using einsum + state_q = state[:, 0, :, :] + state_k = state[:, 1, :, :] + state_v = state[:, 2, :, :] + state_q = ( + state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + state_k = ( + state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + state_v = ( + state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player + new_state[name] = torch.tensor(state_q) + name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player + new_state[name] = torch.tensor(state_k) + name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player + new_state[name] = torch.tensor(state_v) + elif key_name.endswith("/o/kernel"): + name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player + state = ( + vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy() + ) # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/an"): + player = int(key_name[8:].split("/")[0]) + if key_name.endswith("/b"): + name = "model.blocks.%d.self_attn.norm.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/g"): + name = "model.blocks.%d.self_attn.norm.weight" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif ( + key_name.startswith("model/wte") + or key_name.startswith("model/wpe") + or key_name.startswith("model/ete") + ): + nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[ + key_name[-3:] + ] + name = "model.%s.weight" % nlayer + state = vnp.copy() # same in embedded + new_state[name] = torch.tensor(state) + if key_name.startswith("model/wte"): + name = "lm_head.weight" + state = vnp.copy() # same in embedded + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/wob"): + name = "final_logits_bias" + state = vnp.copy() # same in embedded + state = state.reshape((1, -1)) + new_state[name] = torch.tensor(state) + elif key_name == "model/dense/kernel": + name = "model.last_project.weight" + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name == "model/dense_1/bias": + name = "model.last_project.bias" + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + torch.save(new_state, args.output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model") + parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model") + args = parser.parse_args() + convert_tf_gptsan_to_pt(args) diff --git a/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..2c7b687c4d98 --- /dev/null +++ b/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Transformer XL checkpoint and datasets.""" + +import argparse +import os +import pickle +import sys + +import torch + +from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl +from transformers.models.deprecated.transfo_xl import tokenization_transfo_xl as data_utils +from transformers.models.deprecated.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + +# We do this to be able to load python 2 datasets pickles +# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 +data_utils.Vocab = data_utils.TransfoXLTokenizer +data_utils.Corpus = data_utils.TransfoXLCorpus +sys.modules["data_utils"] = data_utils +sys.modules["vocabulary"] = data_utils + + +def convert_transfo_xl_checkpoint_to_pytorch( + tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file +): + if transfo_xl_dataset_file: + # Convert a pre-processed corpus (see original TensorFlow repo) + with open(transfo_xl_dataset_file, "rb") as fp: + corpus = pickle.load(fp, encoding="latin1") + # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) + pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"] + print(f"Save vocabulary to {pytorch_vocab_dump_path}") + corpus_vocab_dict = corpus.vocab.__dict__ + torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) + + corpus_dict_no_vocab = corpus.__dict__ + corpus_dict_no_vocab.pop("vocab", None) + pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME + print(f"Save dataset to {pytorch_dataset_dump_path}") + torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) + + if tf_checkpoint_path: + # Convert a pre-trained TensorFlow model + config_path = os.path.abspath(transfo_xl_config_file) + tf_path = os.path.abspath(tf_checkpoint_path) + + print(f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}.") + # Initialise PyTorch model + if transfo_xl_config_file == "": + config = TransfoXLConfig() + else: + config = TransfoXLConfig.from_json_file(transfo_xl_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = TransfoXLLMHeadModel(config) + + model = load_tf_weights_in_transfo_xl(model, config, tf_path) + # Save pytorch-model + pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) + pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) + print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the folder to store the PyTorch model or dataset/vocab.", + ) + parser.add_argument( + "--tf_checkpoint_path", + default="", + type=str, + help="An optional path to a TensorFlow checkpoint path to be converted.", + ) + parser.add_argument( + "--transfo_xl_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--transfo_xl_dataset_file", + default="", + type=str, + help="An optional dataset file to be converted in a vocabulary.\n" + "Given the files are in the pickle format, please be wary of passing it files you trust.", + ) + args = parser.parse_args() + convert_transfo_xl_checkpoint_to_pytorch( + args.tf_checkpoint_path, + args.transfo_xl_config_file, + args.pytorch_dump_folder_path, + args.transfo_xl_dataset_file, + ) diff --git a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..b0abc30cd758 --- /dev/null +++ b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ELECTRA checkpoint.""" + +import argparse + +import torch + +from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): + # Initialise PyTorch model + config = ElectraConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + + if discriminator_or_generator == "discriminator": + model = ElectraForPreTraining(config) + elif discriminator_or_generator == "generator": + model = ElectraForMaskedLM(config) + else: + raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'") + + # Load weights from tf checkpoint + load_tf_weights_in_electra( + model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator + ) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--discriminator_or_generator", + default=None, + type=str, + required=True, + help=( + "Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or " + "'generator'." + ), + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator + ) diff --git a/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..71660354db14 --- /dev/null +++ b/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert FNet checkpoint.""" + +import argparse + +import torch +from flax.training.checkpoints import restore_checkpoint + +from transformers import FNetConfig, FNetForPreTraining +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, fnet_config_file, save_path): + # Initialise PyTorch model + config = FNetConfig.from_json_file(fnet_config_file) + print(f"Building PyTorch model from configuration: {config}") + fnet_pretraining_model = FNetForPreTraining(config) + + checkpoint_dict = restore_checkpoint(flax_checkpoint_path, None) + pretrained_model_params = checkpoint_dict["target"] + + # Embeddings + # Position IDs + state_dict = fnet_pretraining_model.state_dict() + + position_ids = state_dict["fnet.embeddings.position_ids"] + new_state_dict = {"fnet.embeddings.position_ids": position_ids} + # Embedding Layers + new_state_dict["fnet.embeddings.word_embeddings.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["word"]["embedding"] + ) + new_state_dict["fnet.embeddings.position_embeddings.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["position"]["embedding"][0] + ) + new_state_dict["fnet.embeddings.token_type_embeddings.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["type"]["embedding"] + ) + new_state_dict["fnet.embeddings.projection.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["kernel"] + ).T + new_state_dict["fnet.embeddings.projection.bias"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["bias"] + ) + new_state_dict["fnet.embeddings.LayerNorm.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["layer_norm"]["scale"] + ) + new_state_dict["fnet.embeddings.LayerNorm.bias"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["layer_norm"]["bias"] + ) + + # Encoder Layers + for layer in range(config.num_hidden_layers): + new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.weight"] = torch.tensor( + pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["scale"] + ) + new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.bias"] = torch.tensor( + pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["bias"] + ) + + new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.weight"] = torch.tensor( + pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["kernel"] + ).T + new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.bias"] = torch.tensor( + pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["bias"] + ) + + new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.weight"] = torch.tensor( + pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["kernel"] + ).T + new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.bias"] = torch.tensor( + pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["bias"] + ) + + new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.weight"] = torch.tensor( + pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["scale"] + ) + new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.bias"] = torch.tensor( + pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["bias"] + ) + + # Pooler Layers + new_state_dict["fnet.pooler.dense.weight"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["kernel"]).T + new_state_dict["fnet.pooler.dense.bias"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["bias"]) + + # Masked LM Layers + new_state_dict["cls.predictions.transform.dense.weight"] = torch.tensor( + pretrained_model_params["predictions_dense"]["kernel"] + ).T + new_state_dict["cls.predictions.transform.dense.bias"] = torch.tensor( + pretrained_model_params["predictions_dense"]["bias"] + ) + new_state_dict["cls.predictions.transform.LayerNorm.weight"] = torch.tensor( + pretrained_model_params["predictions_layer_norm"]["scale"] + ) + new_state_dict["cls.predictions.transform.LayerNorm.bias"] = torch.tensor( + pretrained_model_params["predictions_layer_norm"]["bias"] + ) + new_state_dict["cls.predictions.decoder.weight"] = torch.tensor( + pretrained_model_params["encoder"]["embedder"]["word"]["embedding"] + ) + new_state_dict["cls.predictions.decoder.bias"] = torch.tensor( + pretrained_model_params["predictions_output"]["output_bias"] + ) + new_state_dict["cls.predictions.bias"] = torch.tensor(pretrained_model_params["predictions_output"]["output_bias"]) + + # Seq Relationship Layers + new_state_dict["cls.seq_relationship.weight"] = torch.tensor( + pretrained_model_params["classification"]["output_kernel"] + ) + new_state_dict["cls.seq_relationship.bias"] = torch.tensor( + pretrained_model_params["classification"]["output_bias"] + ) + + # Load State Dict + fnet_pretraining_model.load_state_dict(new_state_dict) + + # Save PreTrained + print(f"Saving pretrained model to {save_path}") + fnet_pretraining_model.save_pretrained(save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--fnet_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained FNet model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument("--save_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.fnet_config_file, args.save_path) diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..4eab188f2ab7 --- /dev/null +++ b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Funnel checkpoint.""" + +import argparse + +import torch + +from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model): + # Initialise PyTorch model + config = FunnelConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = FunnelBaseModel(config) if base_model else FunnelModel(config) + + # Load weights from tf checkpoint + load_tf_weights_in_funnel(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--base_model", action="store_true", help="Whether you want just the base model (no decoder) or not." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.base_model + ) diff --git a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..33f9dabed07f --- /dev/null +++ b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI GPT checkpoint.""" + +import argparse + +import torch + +from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2 +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + + +def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): + # Construct model + if gpt2_config_file == "": + config = GPT2Config() + else: + config = GPT2Config.from_json_file(gpt2_config_file) + model = GPT2Model(config) + + # Load weights from numpy + load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME + print(f"Save PyTorch model to {pytorch_weights_dump_path}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {pytorch_config_dump_path}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--gpt2_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained OpenAI model. \n" + "This specifies the model architecture." + ), + ) + args = parser.parse_args() + convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py new file mode 100644 index 000000000000..3db22857293c --- /dev/null +++ b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert GPT Neo checkpoint.""" + +import argparse +import json + +from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config_json = json.load(open(config_file, "r")) + config = GPTNeoConfig( + hidden_size=config_json["n_embd"], + num_layers=config_json["n_layer"], + num_heads=config_json["n_head"], + attention_types=config_json["attention_types"], + max_position_embeddings=config_json["n_positions"], + resid_dropout=config_json["res_dropout"], + embed_dropout=config_json["embed_dropout"], + attention_dropout=config_json["attn_dropout"], + ) + print(f"Building PyTorch model from configuration: {config}") + model = GPTNeoForCausalLM(config) + + # Load weights from tf checkpoint + load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained mesh-tf model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py b/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py new file mode 100644 index 000000000000..182d66b9af28 --- /dev/null +++ b/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI Image GPT checkpoints.""" + +import argparse + +import torch + +from transformers import ImageGPTConfig, ImageGPTForCausalLM, load_tf_weights_in_imagegpt +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + + +def convert_imagegpt_checkpoint_to_pytorch(imagegpt_checkpoint_path, model_size, pytorch_dump_folder_path): + # Construct configuration depending on size + MODELS = {"small": (512, 8, 24), "medium": (1024, 8, 36), "large": (1536, 16, 48)} + n_embd, n_head, n_layer = MODELS[model_size] # set model hyperparameters + config = ImageGPTConfig(n_embd=n_embd, n_layer=n_layer, n_head=n_head) + model = ImageGPTForCausalLM(config) + + # Load weights from numpy + load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME + print(f"Save PyTorch model to {pytorch_weights_dump_path}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {pytorch_config_dump_path}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--imagegpt_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", + ) + parser.add_argument( + "--model_size", + default=None, + type=str, + required=True, + help="Size of the model (can be either 'small', 'medium' or 'large').", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_imagegpt_checkpoint_to_pytorch( + args.imagegpt_checkpoint_path, args.model_size, args.pytorch_dump_folder_path + ) diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..1dd77bc36f80 --- /dev/null +++ b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert LXMERT checkpoint.""" + +import argparse + +import torch + +from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = LxmertConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = LxmertForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_lxmert(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..022a9d036cdb --- /dev/null +++ b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,58 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch + +from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = MobileBertConfig.from_json_file(mobilebert_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = MobileBertForPreTraining(config) + # Load weights from tf checkpoint + model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path) + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--mobilebert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained MobileBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..1b53bbeab475 --- /dev/null +++ b/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert MobileNetV1 checkpoints from the tensorflow/models library.""" + +import argparse +import json +import re +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + MobileNetV1Config, + MobileNetV1ForImageClassification, + MobileNetV1ImageProcessor, + load_tf_weights_in_mobilenet_v1, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_mobilenet_v1_config(model_name): + config = MobileNetV1Config(layer_norm_eps=0.001) + + if "_quant" in model_name: + raise ValueError("Quantized models are not supported.") + + matches = re.match(r"^mobilenet_v1_([^_]*)_([^_]*)$", model_name) + if matches: + config.depth_multiplier = float(matches[1]) + config.image_size = int(matches[2]) + + # The TensorFlow version of MobileNetV1 predicts 1001 classes instead of + # the usual 1000. The first class (index 0) is "background". + config.num_labels = 1001 + filename = "imagenet-1k-id2label.json" + repo_id = "huggingface/label-files" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k) + 1: v for k, v in id2label.items()} + id2label[0] = "background" + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + return config + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_movilevit_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our MobileNetV1 structure. + """ + config = get_mobilenet_v1_config(model_name) + + # Load 🤗 model + model = MobileNetV1ForImageClassification(config).eval() + + # Load weights from TensorFlow checkpoint + load_tf_weights_in_mobilenet_v1(model, config, checkpoint_path) + + # Check outputs on an image, prepared by MobileNetV1ImageProcessor + image_processor = MobileNetV1ImageProcessor( + crop_size={"width": config.image_size, "height": config.image_size}, + size={"shortest_edge": config.image_size + 32}, + ) + encoding = image_processor(images=prepare_img(), return_tensors="pt") + outputs = model(**encoding) + logits = outputs.logits + + assert logits.shape == (1, 1001) + + if model_name == "mobilenet_v1_1.0_224": + expected_logits = torch.tensor([-4.1739, -1.1233, 3.1205]) + elif model_name == "mobilenet_v1_0.75_192": + expected_logits = torch.tensor([-3.9440, -2.3141, -0.3333]) + else: + expected_logits = None + + if expected_logits is not None: + assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing to the hub...") + repo_id = "google/" + model_name + image_processor.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="mobilenet_v1_1.0_224", + type=str, + help="Name of the MobileNetV1 model you'd like to convert. Should in the form 'mobilenet_v1__'.", + ) + parser.add_argument( + "--checkpoint_path", required=True, type=str, help="Path to the original TensorFlow checkpoint (.ckpt file)." + ) + parser.add_argument( + "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_movilevit_checkpoint( + args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub + ) diff --git a/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..1fdb9783ccf0 --- /dev/null +++ b/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert MobileNetV2 checkpoints from the tensorflow/models library.""" + +import argparse +import json +import re +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + MobileNetV2Config, + MobileNetV2ForImageClassification, + MobileNetV2ForSemanticSegmentation, + MobileNetV2ImageProcessor, + load_tf_weights_in_mobilenet_v2, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_mobilenet_v2_config(model_name): + config = MobileNetV2Config(layer_norm_eps=0.001) + + if "quant" in model_name: + raise ValueError("Quantized models are not supported.") + + matches = re.match(r"^.*mobilenet_v2_([^_]*)_([^_]*)$", model_name) + if matches: + config.depth_multiplier = float(matches[1]) + config.image_size = int(matches[2]) + + if model_name.startswith("deeplabv3_"): + config.output_stride = 8 + config.num_labels = 21 + filename = "pascal-voc-id2label.json" + else: + # The TensorFlow version of MobileNetV2 predicts 1001 classes instead + # of the usual 1000. The first class (index 0) is "background". + config.num_labels = 1001 + filename = "imagenet-1k-id2label.json" + + repo_id = "huggingface/label-files" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + + if config.num_labels == 1001: + id2label = {int(k) + 1: v for k, v in id2label.items()} + id2label[0] = "background" + else: + id2label = {int(k): v for k, v in id2label.items()} + + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + return config + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_movilevit_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our MobileNetV2 structure. + """ + config = get_mobilenet_v2_config(model_name) + + # Load 🤗 model + if model_name.startswith("deeplabv3_"): + model = MobileNetV2ForSemanticSegmentation(config).eval() + else: + model = MobileNetV2ForImageClassification(config).eval() + + # Load weights from TensorFlow checkpoint + load_tf_weights_in_mobilenet_v2(model, config, checkpoint_path) + + # Check outputs on an image, prepared by MobileNetV2ImageProcessor + image_processor = MobileNetV2ImageProcessor( + crop_size={"width": config.image_size, "height": config.image_size}, + size={"shortest_edge": config.image_size + 32}, + ) + encoding = image_processor(images=prepare_img(), return_tensors="pt") + outputs = model(**encoding) + logits = outputs.logits + + if model_name.startswith("deeplabv3_"): + assert logits.shape == (1, 21, 65, 65) + + if model_name == "deeplabv3_mobilenet_v2_1.0_513": + expected_logits = torch.tensor( + [ + [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]], + [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]], + [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]], + ] + ) + + else: + raise ValueError(f"Unknown model name: {model_name}") + + assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-4) + else: + assert logits.shape == (1, 1001) + + if model_name == "mobilenet_v2_1.4_224": + expected_logits = torch.tensor([0.0181, -1.0015, 0.4688]) + elif model_name == "mobilenet_v2_1.0_224": + expected_logits = torch.tensor([0.2445, -1.1993, 0.1905]) + elif model_name == "mobilenet_v2_0.75_160": + expected_logits = torch.tensor([0.2482, 0.4136, 0.6669]) + elif model_name == "mobilenet_v2_0.35_96": + expected_logits = torch.tensor([0.1451, -0.4624, 0.7192]) + else: + expected_logits = None + + if expected_logits is not None: + assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing to the hub...") + repo_id = "google/" + model_name + image_processor.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="mobilenet_v2_1.0_224", + type=str, + help="Name of the MobileNetV2 model you'd like to convert. Should in the form 'mobilenet_v2__'.", + ) + parser.add_argument( + "--checkpoint_path", required=True, type=str, help="Path to the original TensorFlow checkpoint (.ckpt file)." + ) + parser.add_argument( + "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_movilevit_checkpoint( + args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub + ) diff --git a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..39653e4b1c77 --- /dev/null +++ b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2024 The MyT5 authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert MyT5 checkpoint.""" + +import argparse + +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +# Copied from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained MyT5 model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..3d5218c20426 --- /dev/null +++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI GPT checkpoint.""" + +import argparse + +import torch + +from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + + +def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): + # Construct model + if openai_config_file == "": + config = OpenAIGPTConfig() + else: + config = OpenAIGPTConfig.from_json_file(openai_config_file) + model = OpenAIGPTModel(config) + + # Load weights from numpy + load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME + print(f"Save PyTorch model to {pytorch_weights_dump_path}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {pytorch_config_dump_path}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--openai_checkpoint_folder_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--openai_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained OpenAI model. \n" + "This specifies the model architecture." + ), + ) + args = parser.parse_args() + convert_openai_checkpoint_to_pytorch( + args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path + ) diff --git a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py new file mode 100644 index 000000000000..ea766c366f34 --- /dev/null +++ b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OWL-ViT checkpoints from the original repository. URL: +https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit""" + +import argparse +import collections + +import jax +import jax.numpy as jnp +import torch +import torch.nn as nn +from clip.model import CLIP +from flax.training import checkpoints +from huggingface_hub import Repository + +from transformers import ( + CLIPTokenizer, + OwlViTConfig, + OwlViTForObjectDetection, + OwlViTImageProcessor, + OwlViTModel, + OwlViTProcessor, +) + + +CONFIGS = { + "vit_b32": { + "embed_dim": 512, + "image_resolution": 768, + "context_length": 16, + "vocab_size": 49408, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 32, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12, + }, + "vit_b16": { + "embed_dim": 512, + "image_resolution": 768, + "context_length": 16, + "vocab_size": 49408, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 16, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12, + }, + "vit_l14": { + "embed_dim": 768, + "image_resolution": 840, + "context_length": 16, + "vocab_size": 49408, + "vision_layers": 24, + "vision_width": 1024, + "vision_patch_size": 14, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12, + }, +} + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def to_f32(params): + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params) + + +def copy_attn_layer(hf_attn_layer, pt_attn_layer): + q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0) + q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0) + + out_proj_weights = pt_attn_layer.out_proj.weight + out_proj_bias = pt_attn_layer.out_proj.bias + + hf_attn_layer.q_proj.weight.data = q_proj + hf_attn_layer.q_proj.bias.data = q_proj_bias + + hf_attn_layer.k_proj.weight.data = k_proj + hf_attn_layer.k_proj.bias.data = k_proj_bias + + hf_attn_layer.v_proj.weight.data = v_proj + hf_attn_layer.v_proj.bias.data = v_proj_bias + + hf_attn_layer.out_proj.weight = out_proj_weights + hf_attn_layer.out_proj.bias = out_proj_bias + + +def copy_mlp(hf_mlp, pt_mlp): + copy_linear(hf_mlp.fc1, pt_mlp.c_fc) + copy_linear(hf_mlp.fc2, pt_mlp.c_proj) + + +def copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + +def copy_layer(hf_layer, pt_layer): + # copy layer norms + copy_linear(hf_layer.layer_norm1, pt_layer.ln_1) + copy_linear(hf_layer.layer_norm2, pt_layer.ln_2) + + # copy MLP + copy_mlp(hf_layer.mlp, pt_layer.mlp) + + # copy attn + copy_attn_layer(hf_layer.self_attn, pt_layer.attn) + + +def copy_layers(hf_layers, pt_layers): + for hf_layer, pt_layer in zip(hf_layers, pt_layers): + copy_layer(hf_layer, pt_layer) + + +def copy_encoder(hf_encoder, pt_model): + # copy embeds + hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight + hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding + + # copy layer norm + copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final) + + # copy hidden layers + copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks) + + +def copy_text_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.text_projection.weight.data = pt_model.text_projection.data.T + + # copy text encoder + copy_encoder(hf_model.text_model, pt_model) + + +def copy_vision_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T + + # copy layer norms + copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre) + copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) + + # copy embeds + hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data + hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding + hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data + + # copy encoder + copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) + + +def copy_class_merge_token(hf_model, flax_params): + flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"]) + + weight = torch.from_numpy(flax_class_token_params["scale"]) + bias = torch.from_numpy(flax_class_token_params["bias"]) + hf_model.layer_norm.weight = nn.Parameter(weight) + hf_model.layer_norm.bias = nn.Parameter(bias) + + +def copy_class_box_heads(hf_model, flax_params): + pt_params = hf_model.state_dict() + new_params = {} + + # Rename class prediction head flax params to pytorch HF + flax_class_params = flatten_nested_dict(flax_params["class_head"]) + + for flax_key, v in flax_class_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("Dense_0", "dense0") + torch_key = "class_head." + torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Rename box prediction box flax params to pytorch HF + flax_box_params = flatten_nested_dict(flax_params["obj_box_head"]) + + for flax_key, v in flax_box_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("_", "").lower() + torch_key = "box_head." + torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Copy flax params to PyTorch params + for name, param in new_params.items(): + if name in pt_params: + pt_params[name].copy_(param) + + +def copy_flax_attn_params(hf_backbone, flax_attn_params): + for k, v in flax_attn_params.items(): + if k.startswith("transformer"): + torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers") + else: + torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers") + + torch_key = torch_key.replace("attn", "self_attn") + torch_key = torch_key.replace("key", "k_proj") + torch_key = torch_key.replace("value", "v_proj") + torch_key = torch_key.replace("query", "q_proj") + torch_key = torch_key.replace("out", "out_proj") + + if "bias" in torch_key and v.ndim == 2: + shape = v.shape[0] * v.shape[1] + v = v.reshape(shape) + + if "weight" in torch_key and "out" in torch_key: + shape = (v.shape[0] * v.shape[1], v.shape[2]) + v = v.reshape(shape).T + + if "weight" in torch_key and "out" not in torch_key: + shape = (v.shape[0], v.shape[1] * v.shape[2]) + v = v.reshape(shape).T + + # Copy flax CLIP attn params to HF PyTorch params + v = torch.from_numpy(v) + hf_backbone.state_dict()[torch_key].copy_(v) + + +def _convert_attn_layers(params): + new_params = {} + processed_attn_layers = [] + + for k, v in params.items(): + if "attn." in k: + base = k[: k.rindex("attn.") + 5] + if base in processed_attn_layers: + continue + + processed_attn_layers.append(base) + dim = params[base + "out.weight"].shape[-1] + new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T + new_params[base + "out_proj.bias"] = params[base + "out.bias"] + else: + new_params[k] = v + return new_params + + +def convert_clip_backbone(flax_params, torch_config): + torch_model = CLIP(**torch_config) + torch_model.eval() + torch_clip_params = torch_model.state_dict() + + flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"]) + new_torch_params = {} + + for flax_key, v in flax_clip_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel") + + if ( + torch_key.startswith("text.transformer") + or torch_key.startswith("text.text_projection") + or torch_key.startswith("text.ln_final") + or torch_key.startswith("text.positional_embedding") + ): + torch_key = torch_key[5:] + + torch_key = torch_key.replace("text_projection.kernel", "text_projection") + torch_key = torch_key.replace("visual.proj.kernel", "visual.proj") + torch_key = torch_key.replace(".scale", ".weight") + torch_key = torch_key.replace(".kernel", ".weight") + + if "conv" in torch_key or "downsample.0.weight" in torch_key: + v = v.transpose(3, 2, 0, 1) + + elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key: + # Fully connected layers are transposed, embeddings are not + v = v.T + + new_torch_params[torch_key] = v + + attn_params = _convert_attn_layers(new_torch_params) + new_torch_params.update(attn_params) + attn_params = {} + + # Copy flax CLIP backbone params to PyTorch params + for name, param in new_torch_params.items(): + if name in torch_clip_params: + new_param = torch.from_numpy(param) + torch_clip_params[name].copy_(new_param) + else: + attn_params[name] = param + + return torch_clip_params, torch_model, attn_params + + +@torch.no_grad() +def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}") + repo.git_pull() + + if config_path is not None: + config = OwlViTConfig.from_pretrained(config_path) + else: + config = OwlViTConfig() + + hf_backbone = OwlViTModel(config).eval() + hf_model = OwlViTForObjectDetection(config).eval() + + copy_text_model_and_projection(hf_backbone, pt_backbone) + copy_vision_model_and_projection(hf_backbone, pt_backbone) + hf_backbone.logit_scale = pt_backbone.logit_scale + copy_flax_attn_params(hf_backbone, attn_params) + + hf_model.owlvit = hf_backbone + copy_class_merge_token(hf_model, flax_params) + copy_class_box_heads(hf_model, flax_params) + + # Save HF model + hf_model.save_pretrained(repo.local_dir) + + # Initialize image processor + image_processor = OwlViTImageProcessor( + size=config.vision_config.image_size, crop_size=config.vision_config.image_size + ) + # Initialize tokenizer + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16) + + # Initialize processor + processor = OwlViTProcessor(image_processor=image_processor, tokenizer=tokenizer) + image_processor.save_pretrained(repo.local_dir) + processor.save_pretrained(repo.local_dir) + + repo.git_add() + repo.git_commit("Upload model and processor") + repo.git_push() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--owlvit_version", + default=None, + type=str, + required=True, + help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].", + ) + parser.add_argument( + "--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint." + ) + parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.") + parser.add_argument( + "--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + + # Initialize PyToch clip model + model_name = args.owlvit_version + if model_name == "clip_b16": + torch_config = CONFIGS["vit_b16"] + elif model_name == "clip_b32": + torch_config = CONFIGS["vit_b32"] + elif model_name == "clip_l14": + torch_config = CONFIGS["vit_l14"] + + # Load from checkpoint and convert params to float-32 + variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"] + flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) + del variables + + # Convert CLIP backbone + pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config) + + convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config) diff --git a/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py b/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py new file mode 100644 index 000000000000..9251c9a92ac6 --- /dev/null +++ b/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path + +import tensorflow as tf +import torch +from tqdm import tqdm + +from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer +from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params + + +PATTERNS = [ + # replace left string with right string to get the relevant state_dict key (identical state dict to bart) + ["memory_attention", "encoder_attn"], + ["attention", "attn"], + ["/", "."], + [".LayerNorm.gamma", "_layer_norm.weight"], + [".LayerNorm.beta", "_layer_norm.bias"], + ["r.layer_", "r.layers."], + ["output_proj", "out_proj"], + ["ffn.dense_1.", "fc2."], + ["ffn.dense.", "fc1."], + ["ffn_layer_norm", "final_layer_norm"], + ["kernel", "weight"], + ["encoder_layer_norm.", "encoder.layer_norm."], + ["decoder_layer_norm.", "decoder.layer_norm."], + ["embeddings.weights", "shared.weight"], +] + + +def rename_state_dict_key(k): + for pegasus_name, hf_name in PATTERNS: + k = k.replace(pegasus_name, hf_name) + return k + + +# See appendix C of paper for all hyperparams + + +def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: + cfg_kwargs = DEFAULTS.copy() + cfg_kwargs.update(cfg_updates) + cfg = PegasusConfig(**cfg_kwargs) + torch_model = PegasusForConditionalGeneration(cfg) + sd = torch_model.model.state_dict() + mapping = {} + for k, v in tf_weights.items(): + new_k = rename_state_dict_key(k) + if new_k not in sd: + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + + if "dense" in k or "proj" in new_k: + v = v.T + mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype) + assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}" + # make sure embedding.padding_idx is respected + mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1]) + mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"] + mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"] + empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping} + mapping.update(**empty_biases) + missing, extra = torch_model.model.load_state_dict(mapping, strict=False) + unexpected_missing = [ + k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"] + ] + assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" + assert extra == [], f"no matches found for the following tf keys {extra}" + return torch_model + + +def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> dict: + init_vars = tf.train.list_variables(path) + tf_weights = {} + ignore_name = ["Adafactor", "global_step"] + for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): + skip_key = any(pat in name for pat in ignore_name) + if skip_key: + continue + array = tf.train.load_variable(path, name) + tf_weights[name] = array + return tf_weights + + +def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): + # save tokenizer first + dataset = Path(ckpt_path).parent.name + desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"] + tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length) + assert tok.model_max_length == desired_max_model_length + tok.save_pretrained(save_dir) + + # convert model + tf_weights = get_tf_weights_as_numpy(ckpt_path) + cfg_updates = task_specific_params[f"summarization_{dataset}"] + if dataset == "large": + cfg_updates["task_specific_params"] = task_specific_params + torch_model = convert_pegasus(tf_weights, cfg_updates) + torch_model.save_pretrained(save_dir) + sd = torch_model.state_dict() + sd.pop("model.decoder.embed_positions.weight") + sd.pop("model.encoder.embed_positions.weight") + torch.save(sd, Path(save_dir) / "pytorch_model.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables") + parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + if args.save_dir is None: + dataset = Path(args.tf_ckpt_path).parent.name + args.save_dir = os.path.join("pegasus", dataset) + convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir) diff --git a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..369388c540f9 --- /dev/null +++ b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RemBERT checkpoint.""" + +import argparse + +import torch + +from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = RemBertConfig.from_json_file(bert_config_file) + print(f"Building PyTorch model from configuration: {str(config)}") + model = RemBertModel(config) + + # Load weights from tf checkpoint + load_tf_weights_in_rembert(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--rembert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained RemBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_rembert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.rembert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..d227948e0ee3 --- /dev/null +++ b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RoFormer checkpoint.""" + +import argparse + +import torch + +from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = RoFormerConfig.from_json_file(bert_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = RoFormerForMaskedLM(config) + + # Load weights from tf checkpoint + load_tf_weights_in_roformer(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--bert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py new file mode 100644 index 000000000000..9286fae776fd --- /dev/null +++ b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py @@ -0,0 +1,121 @@ +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from torch import nn + +from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "decoder.output_projection.weight", + "_float_tensor", + "encoder.embed_positions._float_tensor", + "decoder.embed_positions._float_tensor", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_keys(s_dict): + keys = list(s_dict.keys()) + for key in keys: + if "transformer_layers" in key: + s_dict[key.replace("transformer_layers", "layers")] = s_dict.pop(key) + elif "subsample" in key: + s_dict[key.replace("subsample", "conv")] = s_dict.pop(key) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_path): + m2m_100 = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + args = m2m_100["args"] + state_dict = m2m_100["model"] + lm_head_weights = state_dict["decoder.output_projection.weight"] + + remove_ignore_keys_(state_dict) + rename_keys(state_dict) + + vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0] + + tie_embeds = args.share_decoder_input_output_embed + + conv_kernel_sizes = [int(i) for i in args.conv_kernel_sizes.split(",")] + config = Speech2TextConfig( + vocab_size=vocab_size, + max_source_positions=args.max_source_positions, + max_target_positions=args.max_target_positions, + encoder_layers=args.encoder_layers, + decoder_layers=args.decoder_layers, + encoder_attention_heads=args.encoder_attention_heads, + decoder_attention_heads=args.decoder_attention_heads, + encoder_ffn_dim=args.encoder_ffn_embed_dim, + decoder_ffn_dim=args.decoder_ffn_embed_dim, + d_model=args.encoder_embed_dim, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_function="relu", + num_conv_layers=len(conv_kernel_sizes), + conv_channels=args.conv_channels, + conv_kernel_sizes=conv_kernel_sizes, + input_feat_per_channel=args.input_feat_per_channel, + input_channels=args.input_channels, + tie_word_embeddings=tie_embeds, + num_beams=5, + max_length=200, + use_cache=True, + decoder_start_token_id=2, + early_stopping=True, + ) + + model = Speech2TextForConditionalGeneration(config) + missing, unexpected = model.model.load_state_dict(state_dict, strict=False) + if len(missing) > 0 and not set(missing) <= { + "encoder.embed_positions.weights", + "decoder.embed_positions.weights", + }: + raise ValueError( + "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing," + f" but all the following weights are missing {missing}" + ) + + if tie_embeds: + model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens) + else: + model.lm_head.weight.data = lm_head_weights + + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--fairseq_path", type=str, help="Path to the fairseq model (.pt) file.") + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..71d304ea96c6 --- /dev/null +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert SwitchTransformersX checkpoints from the original repository to JAX/FLAX model.""" + +import argparse +import re + +from flax.traverse_util import flatten_dict, unflatten_dict +from t5x import checkpoints + +from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration +from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model +from transformers.utils import logging + + +logging.set_verbosity_info() + + +# should not include what is already done by the `from_pt` argument +MOE_LAYER_NAME_MAPPING = { + "/attention/": "/0/SelfAttention/", + "/self_attention/": "/0/SelfAttention/", + "/encoder_decoder_attention/": "/1/EncDecAttention/", + "value": "v", + "query": "q", + "key": "k", + "out": "o", + "pre_self_attention_layer_norm": "0/layer_norm", + "pre_cross_attention_layer_norm": "1/layer_norm", + "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong + "token_embedder": "shared", + "encoder_norm": "final_layer_norm", + "decoder_norm": "final_layer_norm", + "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", + "router/router_weights/w/": "router/classifier/", + "roer/roer_weights/w/": "router/classifier/", + "logits_dense": "lm_head", +} + + +def rename_keys(s_dict): + # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in + # the original model + keys = list(s_dict.keys()) + for key in keys: + layer_to_block_of_layer = r".*/layers_(\d+)" + new_key = key + if re.match(layer_to_block_of_layer, key): + new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) + + layer_to_block_of_layer = r"(encoder|decoder)\/" + + if re.match(layer_to_block_of_layer, key): + groups = re.match(layer_to_block_of_layer, new_key).groups() + if groups[0] == "encoder": + new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) + + elif groups[0] == "decoder": + new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key) + + # 2. Convert other classic mappings + for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, temp_key) + + print(f"{key} -> {new_key}") + s_dict[new_key] = s_dict.pop(key) + + if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + + # 3. Take extra care of the EXPERTS layer + for key in list(s_dict.keys()): + if "expert" in key: + num_experts = s_dict[key].shape[0] + expert_weihts = s_dict[key] + for idx in range(num_experts): + s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] + print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}") + + s_dict.pop(key) + + return s_dict + + +GIN_TO_CONFIG_MAPPING = { + "NUM_ENCODER_LAYERS": "num_layers", + "NUM_DECODER_LAYERS": "num_decoder_layers", + "NUM_HEADS": "num_heads", + "HEAD_DIM": "d_kv", + "EMBED_DIM": "d_model", + "MLP_DIM": "d_ff", + "NUM_SELECTED_EXPERTS": "num_selected_experts", + "NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers", + "NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers", + "dense.MlpBlock.activations": "feed_forward_proj", +} + + +def convert_gin_to_config(gin_file, num_experts): + # Convert a google style config to the hugging face format + import regex as re + + with open(gin_file, "r") as f: + raw_gin = f.read() + + regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin) + args = {} + for param, value in regex_match: + if param in GIN_TO_CONFIG_MAPPING and value != "": + args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value) + + activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0] + args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1]) + + args["num_experts"] = num_experts + config = SwitchTransformersConfig(**args) + return config + + +def convert_flax_checkpoint_to_pytorch( + flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8 +): + # Initialise PyTorch model + + print(f"Loading flax weights from : {flax_checkpoint_path}") + flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path) + + if gin_file is not None: + config = convert_gin_to_config(gin_file, num_experts) + else: + config = SwitchTransformersConfig.from_pretrained(config_file) + + pt_model = SwitchTransformersForConditionalGeneration(config) + + flax_params = flax_params["target"] + flax_params = flatten_dict(flax_params, sep="/") + flax_params = rename_keys(flax_params) + flax_params = unflatten_dict(flax_params, sep="/") + + # Load the flax params in the PT model + load_flax_weights_in_pytorch_model(pt_model, flax_params) + + print(f"Save PyTorch model to {pytorch_dump_path}") + pt_model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--switch_t5x_checkpoint_path", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the" + " model architecture. If not provided, a `gin_file` has to be provided." + ), + ) + parser.add_argument( + "--gin_file", + default=None, + type=str, + required=False, + help="Path to the gin config file. If not provided, a `config_file` has to be passed ", + ) + parser.add_argument( + "--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." + ) + parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts") + args = parser.parse_args() + convert_flax_checkpoint_to_pytorch( + args.switch_t5x_checkpoint_path, + args.config_name, + args.gin_file, + args.pytorch_dump_folder_path, + args.num_experts, + ) diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..9b1b15857cea --- /dev/null +++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2018 The T5 authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert T5 checkpoint.""" + +import argparse + +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..34bf77cccd6b --- /dev/null +++ b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert TAPAS checkpoint.""" + +import argparse + +from transformers import ( + TapasConfig, + TapasForMaskedLM, + TapasForQuestionAnswering, + TapasForSequenceClassification, + TapasModel, + TapasTokenizer, + load_tf_weights_in_tapas, +) +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch( + task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path +): + # Initialise PyTorch model. + # If you want to convert a checkpoint that uses absolute position embeddings, make sure to set reset_position_index_per_cell of + # TapasConfig to False. + + # initialize configuration from json file + config = TapasConfig.from_json_file(tapas_config_file) + # set absolute/relative position embeddings parameter + config.reset_position_index_per_cell = reset_position_index_per_cell + + # set remaining parameters of TapasConfig as well as the model based on the task + if task == "SQA": + model = TapasForQuestionAnswering(config=config) + elif task == "WTQ": + # run_task_main.py hparams + config.num_aggregation_labels = 4 + config.use_answer_as_supervision = True + # hparam_utils.py hparams + config.answer_loss_cutoff = 0.664694 + config.cell_selection_preference = 0.207951 + config.huber_loss_delta = 0.121194 + config.init_cell_selection_weights_to_zero = True + config.select_one_column = True + config.allow_empty_column_selection = False + config.temperature = 0.0352513 + + model = TapasForQuestionAnswering(config=config) + elif task == "WIKISQL_SUPERVISED": + # run_task_main.py hparams + config.num_aggregation_labels = 4 + config.use_answer_as_supervision = False + # hparam_utils.py hparams + config.answer_loss_cutoff = 36.4519 + config.cell_selection_preference = 0.903421 + config.huber_loss_delta = 222.088 + config.init_cell_selection_weights_to_zero = True + config.select_one_column = True + config.allow_empty_column_selection = True + config.temperature = 0.763141 + + model = TapasForQuestionAnswering(config=config) + elif task == "TABFACT": + model = TapasForSequenceClassification(config=config) + elif task == "MLM": + model = TapasForMaskedLM(config=config) + elif task == "INTERMEDIATE_PRETRAINING": + model = TapasModel(config=config) + else: + raise ValueError(f"Task {task} not supported.") + + print(f"Building PyTorch model from configuration: {config}") + # Load weights from tf checkpoint + load_tf_weights_in_tapas(model, config, tf_checkpoint_path) + + # Save pytorch-model (weights and configuration) + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Save tokenizer files + print(f"Save tokenizer files to {pytorch_dump_path}") + tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512) + tokenizer.save_pretrained(pytorch_dump_path) + + print("Used relative position embeddings:", model.config.reset_position_index_per_cell) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--task", default="SQA", type=str, help="Model task for which to convert a checkpoint. Defaults to SQA." + ) + parser.add_argument( + "--reset_position_index_per_cell", + default=False, + action="store_true", + help="Whether to use relative position embeddings or not. Defaults to True.", + ) + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--tapas_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained TAPAS model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.task, + args.reset_position_index_per_cell, + args.tf_checkpoint_path, + args.tapas_config_file, + args.pytorch_dump_path, + ) diff --git a/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py b/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py new file mode 100644 index 000000000000..bf6aa8e4a36b --- /dev/null +++ b/src/transformers/models/vivit/convert_vivit_flax_to_pytorch.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL: +https://github.com/google-research/scenic/tree/main/scenic/projects/vivit +""" + +import argparse +import json +import os.path +from collections import OrderedDict + +import numpy as np +import requests +import torch +from flax.training.checkpoints import restore_checkpoint +from huggingface_hub import hf_hub_download + +from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor +from transformers.image_utils import PILImageResampling + + +def download_checkpoint(path): + url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint" + + with open(path, "wb") as f: + with requests.get(url, stream=True) as req: + for chunk in req.iter_content(chunk_size=2048): + f.write(chunk) + + +def get_vivit_config() -> VivitConfig: + config = VivitConfig() + + config.num_labels = 400 + repo_id = "huggingface/label-files" + filename = "kinetics400-id2label.json" + + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + return config + + +# We will verify our results on a video of eating spaghetti +# Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117, +# 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174] +def prepare_video(): + file = hf_hub_download( + repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset" + ) + video = np.load(file) + return list(video) + + +def transform_attention(current: np.ndarray): + if np.ndim(current) == 2: + return transform_attention_bias(current) + + elif np.ndim(current) == 3: + return transform_attention_kernel(current) + + else: + raise Exception(f"Invalid number of dimensions: {np.ndim(current)}") + + +def transform_attention_bias(current: np.ndarray): + return current.flatten() + + +def transform_attention_kernel(current: np.ndarray): + return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T + + +def transform_attention_output_weight(current: np.ndarray): + return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T + + +def transform_state_encoder_block(state_dict, i): + state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"] + + prefix = f"encoder.layer.{i}." + new_state = { + prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"], + prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]), + prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"], + prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]), + prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"], + prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"], + prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"], + prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"], + prefix + "attention.attention.query.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["query"]["bias"] + ), + prefix + "attention.attention.query.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["query"]["kernel"] + ), + prefix + "attention.attention.key.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["key"]["bias"] + ), + prefix + "attention.attention.key.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["key"]["kernel"] + ), + prefix + "attention.attention.value.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["value"]["bias"] + ), + prefix + "attention.attention.value.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["value"]["kernel"] + ), + prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"], + prefix + "attention.output.dense.weight": transform_attention_output_weight( + state["MultiHeadDotProductAttention_0"]["out"]["kernel"] + ), + } + + return new_state + + +def get_n_layers(state_dict): + return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"]]) + + +def transform_state(state_dict, classification_head=False): + transformer_layers = get_n_layers(state_dict) + + new_state = OrderedDict() + + new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"] + new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"] + + new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose( + state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2) + ) + new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"] + + new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"] + new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][ + "pos_embedding" + ] + + for i in range(transformer_layers): + new_state.update(transform_state_encoder_block(state_dict, i)) + + if classification_head: + new_state = {"vivit." + k: v for k, v in new_state.items()} + new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"]) + new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"]) + + return {k: torch.tensor(v) for k, v in new_state.items()} + + +# checks that image processor settings are the same as in the original implementation +# original: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py +# dataset specific config: +# https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py +def get_processor() -> VivitImageProcessor: + extractor = VivitImageProcessor() + + assert extractor.do_resize is True + assert extractor.size == {"shortest_edge": 256} + assert extractor.do_center_crop is True + assert extractor.crop_size == {"width": 224, "height": 224} + assert extractor.resample == PILImageResampling.BILINEAR + + # here: https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py + # one can seen that add_image has default values for normalization_mean and normalization_std set to 0 and 1 + # which effectively means no normalization (and ViViT does not overwrite those when calling this func) + assert extractor.do_normalize is False + assert extractor.do_rescale is True + assert extractor.rescale_factor == 1 / 255 + + # zero-centering = True in original implementation + assert extractor.do_zero_centering is True + + return extractor + + +def convert(output_path: str): + flax_model_path = "checkpoint" + + if not os.path.exists(flax_model_path): + download_checkpoint(flax_model_path) + + state_dict = restore_checkpoint(flax_model_path, None) + new_state = transform_state(state_dict, classification_head=True) + + config = get_vivit_config() + + assert config.image_size == 224 + assert config.num_frames == 32 + + model = VivitForVideoClassification(config) + model.load_state_dict(new_state) + model.eval() + + extractor = get_processor() + + video = prepare_video() + inputs = extractor(video, return_tensors="pt") + + outputs = model(**inputs) + + expected_shape = torch.Size([1, 400]) + expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]) + + assert outputs.logits.shape == expected_shape + assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5] + + model.save_pretrained(output_path) + extractor.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--output_model_name", "-o", type=str, help="Output path for the converted HuggingFace model") + + args = parser.parse_args() + convert(args.output_model_name) diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..a15c5f22ad68 --- /dev/null +++ b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BERT checkpoint.""" + +import argparse +import os + +import torch + +from transformers import ( + XLNetConfig, + XLNetForQuestionAnswering, + XLNetForSequenceClassification, + XLNetLMHeadModel, + load_tf_weights_in_xlnet, +) +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +GLUE_TASKS_NUM_LABELS = { + "cola": 2, + "mnli": 3, + "mrpc": 2, + "sst-2": 2, + "sts-b": 1, + "qqp": 2, + "qnli": 2, + "rte": 2, + "wnli": 2, +} + + +logging.set_verbosity_info() + + +def convert_xlnet_checkpoint_to_pytorch( + tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None +): + # Initialise PyTorch model + config = XLNetConfig.from_json_file(bert_config_file) + + finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" + if finetuning_task in GLUE_TASKS_NUM_LABELS: + print(f"Building PyTorch XLNetForSequenceClassification model from configuration: {config}") + config.finetuning_task = finetuning_task + config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] + model = XLNetForSequenceClassification(config) + elif "squad" in finetuning_task: + config.finetuning_task = finetuning_task + model = XLNetForQuestionAnswering(config) + else: + model = XLNetLMHeadModel(config) + + # Load weights from tf checkpoint + load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) + + # Save pytorch-model + pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) + pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) + print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--xlnet_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained XLNet model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the folder to store the PyTorch model or dataset/vocab.", + ) + parser.add_argument( + "--finetuning_task", + default=None, + type=str, + help="Name of a task on which the XLNet TensorFlow model was fine-tuned", + ) + args = parser.parse_args() + print(args) + + convert_xlnet_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task + ) From 9b801ec54161c2f05dda9a4b3f853511a134daad Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 17:35:25 +0200 Subject: [PATCH 34/35] explicitly add functions in converters --- ...gbird_original_tf_checkpoint_to_pytorch.py | 185 +++++++++++++++++- ..._byt5_original_tf_checkpoint_to_pytorch.py | 111 ++++++++++- ...anine_original_tf_checkpoint_to_pytorch.py | 103 +++++++++- ...bert_original_tf1_checkpoint_to_pytorch.py | 179 +++++++++++++++++ ...ginal_tf1_checkpoint_to_pytorch_and_tf2.py | 57 ------ ...fo_xl_original_tf_checkpoint_to_pytorch.py | 130 +++++++++++- ...ectra_original_tf_checkpoint_to_pytorch.py | 86 +++++++- ...unnel_original_tf_checkpoint_to_pytorch.py | 95 ++++++++- ..._gpt2_original_tf_checkpoint_to_pytorch.py | 60 +++++- .../convert_gpt_neo_mesh_tf_to_pytorch.py | 87 +++++++- ...onvert_imagegpt_original_tf2_to_pytorch.py | 112 ++++++++++- ...xmert_original_tf_checkpoint_to_pytorch.py | 83 +++++++- ...ebert_original_tf_checkpoint_to_pytorch.py | 82 +++++++- ...nvert_original_tf_checkpoint_to_pytorch.py | 104 +++++++++- ...nvert_original_tf_checkpoint_to_pytorch.py | 170 +++++++++++++++- ..._myt5_original_tf_checkpoint_to_pytorch.py | 111 ++++++++++- ...penai_original_tf_checkpoint_to_pytorch.py | 83 +++++++- ...onvert_rembert_tf_checkpoint_to_pytorch.py | 87 +++++++- ...ormer_original_tf_checkpoint_to_pytorch.py | 77 +++++++- ...rt_t5_original_tf_checkpoint_to_pytorch.py | 111 ++++++++++- ...tapas_original_tf_checkpoint_to_pytorch.py | 139 ++++++++++++- ...xlnet_original_tf_checkpoint_to_pytorch.py | 152 +++++++++++++- 22 files changed, 2327 insertions(+), 77 deletions(-) create mode 100644 src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py diff --git a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py index 0b8e6590f937..9064c7cbdc08 100644 --- a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,197 @@ """Convert BigBird checkpoint.""" import argparse +import math +import os -from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird +import torch + +from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +_TRIVIA_QA_MAPPING = { + "big_bird_attention": "attention/self", + "output_layer_norm": "output/LayerNorm", + "attention_output": "attention/output/dense", + "output": "output/dense", + "self_attention_layer_norm": "attention/output/LayerNorm", + "intermediate": "intermediate/dense", + "word_embeddings": "bert/embeddings/word_embeddings", + "position_embedding": "bert/embeddings/position_embeddings", + "type_embeddings": "bert/embeddings/token_type_embeddings", + "embeddings": "bert/embeddings", + "layer_normalization": "output/LayerNorm", + "layer_norm": "LayerNorm", + "trivia_qa_head": "qa_classifier", + "dense": "intermediate/dense", + "dense_1": "qa_outputs", +} + + +def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False): + """Load tf checkpoints in a pytorch model.""" + + def load_tf_weights_bert(init_vars, tf_path): + names = [] + tf_weights = {} + + for name, shape in init_vars: + array = tf.train.load_variable(tf_path, name) + name = name.replace("bert/encoder/LayerNorm", "bert/embeddings/LayerNorm") + logger.info(f"Loading TF weight {name} with shape {shape}") + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + def load_tf_weights_trivia_qa(init_vars): + names = [] + tf_weights = {} + + for i, var in enumerate(init_vars): + name_items = var.name.split("/") + + if "transformer_scaffold" in name_items[0]: + layer_name_items = name_items[0].split("_") + if len(layer_name_items) < 3: + layer_name_items += [0] + + name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}" + + name = "/".join([_TRIVIA_QA_MAPPING.get(x, x) for x in name_items])[:-2] # remove last :0 in variable + + if "self/attention/output" in name: + name = name.replace("self/attention/output", "output") + + if i >= len(init_vars) - 2: + name = name.replace("intermediate", "output") + + logger.info(f"Loading TF weight {name} with shape {var.shape}") + array = var.value().numpy() + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + + # Load weights from TF model + init_vars = tf.saved_model.load(tf_path).variables if is_trivia_qa else tf.train.list_variables(tf_path) + + if len(init_vars) <= 0: + raise ValueError("Loaded trained variables cannot be empty.") + + pt_names = list(model.state_dict().keys()) + + if is_trivia_qa: + names, tf_weights = load_tf_weights_trivia_qa(init_vars) + else: + names, tf_weights = load_tf_weights_bert(init_vars, tf_path) + + for txt_name in names: + array = tf_weights[txt_name] + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + pt_name = [] + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + pt_name.append("bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + pt_name.append("classifier") + elif scope_names[0] == "transform": + pointer = getattr(pointer, "transform") + pt_name.append("transform") + if ("bias" in name) or ("kernel" in name): + pointer = getattr(pointer, "dense") + pt_name.append("dense") + elif ("beta" in name) or ("gamma" in name): + pointer = getattr(pointer, "LayerNorm") + pt_name.append("LayerNorm") + else: + try: + pointer = getattr(pointer, scope_names[0]) + pt_name.append(f"{scope_names[0]}") + except AttributeError: + logger.info(f"Skipping {m_name}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + pt_name.append(f"{num}") + if m_name[-11:] == "_embeddings" or m_name == "embeddings": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if len(array.shape) > len(pointer.shape) and math.prod(array.shape) == math.prod(pointer.shape): + # print(txt_name, array.shape) + if ( + txt_name.endswith("attention/self/key/kernel") + or txt_name.endswith("attention/self/query/kernel") + or txt_name.endswith("attention/self/value/kernel") + ): + array = array.transpose(1, 0, 2).reshape(pointer.shape) + elif txt_name.endswith("attention/output/dense/kernel"): + array = array.transpose(0, 2, 1).reshape(pointer.shape) + else: + array = array.reshape(pointer.shape) + + if pointer.shape != array.shape: + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}." + ) + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + pt_weight_name = ".".join(pt_name) + logger.info(f"Initialize PyTorch weight {pt_weight_name} from {txt_name}.") + pointer.data = torch.from_numpy(array) + tf_weights.pop(txt_name, None) + pt_names.remove(pt_weight_name) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + logger.info(f"Weights not initialized in PyTorch model: {', '.join(pt_names)}.") + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): # Initialise PyTorch model config = BigBirdConfig.from_json_file(big_bird_config_file) diff --git a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py index 9b1b15857cea..a53efce63544 100755 --- a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,123 @@ """Convert T5 checkpoint.""" import argparse +import os -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +import torch + +from transformers import T5Config, T5ForConditionalGeneration from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): # Initialise PyTorch model config = T5Config.from_json_file(config_file) diff --git a/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py index 45dcdb290333..5c18f64dba1c 100644 --- a/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,115 @@ """Convert CANINE checkpoint.""" import argparse +import os -from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine +import torch + +from transformers import CanineConfig, CanineModel, CanineTokenizer from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_canine(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + # also discard the cls weights (which were used for the next sentence prediction pre-training task) + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + "cls", + "autoregressive_decoder", + "char_output_weights", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + # if first scope name starts with "bert", change it to "encoder" + if name[0] == "bert": + name[0] = "encoder" + # remove "embeddings" middle name of HashBucketCodepointEmbedders + elif name[1] == "embeddings": + name.remove(name[1]) + # rename segment_embeddings to token_type_embeddings + elif name[1] == "segment_embeddings": + name[1] = "token_type_embeddings" + # rename initial convolutional projection layer + elif name[1] == "initial_char_encoder": + name = ["chars_to_molecules"] + name[-2:] + # rename final convolutional projection layer + elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]: + name = ["projection"] + name[1:] + pointer = model + for m_name in name: + if (re.fullmatch(r"[A-Za-z]+_\d+", m_name)) and "Embedder" not in m_name: + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]: + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path): # Initialize PyTorch model config = CanineConfig() diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..26e0328ac06f --- /dev/null +++ b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ConvBERT checkpoint.""" + +import argparse +import os +import torch + +from transformers import ConvBertConfig, ConvBertModel +from transformers.utils import logging + +logger = logging.get_logger(__name__) +logging.set_verbosity_info() + +def load_tf_weights_in_convbert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_data = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + tf_data[name] = array + + param_mapping = { + "embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings", + "embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings", + "embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings", + "embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma", + "embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta", + "embeddings_project.weight": "electra/embeddings_project/kernel", + "embeddings_project.bias": "electra/embeddings_project/bias", + } + if config.num_groups > 1: + group_dense_name = "g_dense" + else: + group_dense_name = "dense" + + for j in range(config.num_hidden_layers): + param_mapping[f"encoder.layer.{j}.attention.self.query.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/query/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.query.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/query/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.self.key.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/key/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.key.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/key/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.self.value.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/value/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.value.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/value/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.depthwise.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_key/depthwise_kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.pointwise.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_key/pointwise_kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_key/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.weight"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_point/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.bias"] = ( + f"electra/encoder/layer_{j}/attention/self/conv_attn_point/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.output.dense.weight"] = ( + f"electra/encoder/layer_{j}/attention/output/dense/kernel" + ) + param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.weight"] = ( + f"electra/encoder/layer_{j}/attention/output/LayerNorm/gamma" + ) + param_mapping[f"encoder.layer.{j}.attention.output.dense.bias"] = ( + f"electra/encoder/layer_{j}/attention/output/dense/bias" + ) + param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.bias"] = ( + f"electra/encoder/layer_{j}/attention/output/LayerNorm/beta" + ) + param_mapping[f"encoder.layer.{j}.intermediate.dense.weight"] = ( + f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/kernel" + ) + param_mapping[f"encoder.layer.{j}.intermediate.dense.bias"] = ( + f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/bias" + ) + param_mapping[f"encoder.layer.{j}.output.dense.weight"] = ( + f"electra/encoder/layer_{j}/output/{group_dense_name}/kernel" + ) + param_mapping[f"encoder.layer.{j}.output.dense.bias"] = ( + f"electra/encoder/layer_{j}/output/{group_dense_name}/bias" + ) + param_mapping[f"encoder.layer.{j}.output.LayerNorm.weight"] = ( + f"electra/encoder/layer_{j}/output/LayerNorm/gamma" + ) + param_mapping[f"encoder.layer.{j}.output.LayerNorm.bias"] = f"electra/encoder/layer_{j}/output/LayerNorm/beta" + + for param in model.named_parameters(): + param_name = param[0] + retriever = attrgetter(param_name) + result = retriever(model) + tf_name = param_mapping[param_name] + value = torch.from_numpy(tf_data[tf_name]) + logger.info(f"TF: {tf_name}, PT: {param_name} ") + if tf_name.endswith("/kernel"): + if not tf_name.endswith("/intermediate/g_dense/kernel"): + if not tf_name.endswith("/output/g_dense/kernel"): + value = value.T + if tf_name.endswith("/depthwise_kernel"): + value = value.permute(1, 2, 0) # 2, 0, 1 + if tf_name.endswith("/pointwise_kernel"): + value = value.permute(2, 1, 0) # 2, 1, 0 + if tf_name.endswith("/conv_attn_key/bias"): + value = value.unsqueeze(-1) + result.data = value + return model + + +def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): + conf = ConvBertConfig.from_json_file(convbert_config_file) + model = ConvBertModel(conf) + + model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--convbert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained ConvBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py deleted file mode 100644 index 3d4ff779874b..000000000000 --- a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert ConvBERT checkpoint.""" - -import argparse - -from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert -from transformers.utils import logging - - -logging.set_verbosity_info() - - -def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): - conf = ConvBertConfig.from_json_file(convbert_config_file) - model = ConvBertModel(conf) - - model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) - model.save_pretrained(pytorch_dump_path) - - tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) - tf_model.save_pretrained(pytorch_dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--convbert_config_file", - default=None, - type=str, - required=True, - help=( - "The config json file corresponding to the pre-trained ConvBERT model. \n" - "This specifies the model architecture." - ), - ) - parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - args = parser.parse_args() - convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py index 2c7b687c4d98..989a70ef71bd 100644 --- a/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py @@ -21,12 +21,13 @@ import torch -from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl +from transformers import TransfoXLConfig, TransfoXLLMHeadModel from transformers.models.deprecated.transfo_xl import tokenization_transfo_xl as data_utils from transformers.models.deprecated.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() # We do this to be able to load python 2 datasets pickles @@ -37,6 +38,133 @@ sys.modules["vocabulary"] = data_utils +def build_tf_to_pytorch_map(model, config): + """ + A map of modules from TF to PyTorch. This time I use a map to keep the PyTorch model as identical to the original + PyTorch model as possible. + """ + tf_to_pt_map = {} + + if hasattr(model, "transformer"): + # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax + tf_to_pt_map.update( + { + "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight, + "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias, + } + ) + for i, (out_l, proj_l, tie_proj) in enumerate( + zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs) + ): + layer_str = f"transformer/adaptive_softmax/cutoff_{i}/" + if config.tie_word_embeddings: + tf_to_pt_map.update({layer_str + "b": out_l.bias}) + else: + raise NotImplementedError + # I don't think this is implemented in the TF code + tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias}) + if not tie_proj: + tf_to_pt_map.update({layer_str + "proj": proj_l}) + # Now load the rest of the transformer + model = model.transformer + + # Embeddings + for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): + layer_str = f"transformer/adaptive_embed/cutoff_{i}/" + tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l}) + + # Transformer blocks + for i, b in enumerate(model.layers): + layer_str = f"transformer/layer_{i}/" + tf_to_pt_map.update( + { + layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight, + layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias, + layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight, + layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight, + layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight, + layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight, + layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias, + layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight, + layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias, + layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight, + layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, + } + ) + + # Relative positioning biases + if config.untie_r: + r_r_list = [] + r_w_list = [] + for b in model.layers: + r_r_list.append(b.dec_attn.r_r_bias) + r_w_list.append(b.dec_attn.r_w_bias) + else: + r_r_list = [model.r_r_bias] + r_w_list = [model.r_w_bias] + tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list}) + return tf_to_pt_map + + +def load_tf_weights_in_transfo_xl(model, config, tf_path): + """Load tf checkpoints in a pytorch model""" + try: + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_to_pytorch_map(model, config) + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + tf_weights[name] = array + + for name, pointer in tf_to_pt_map.items(): + assert name in tf_weights + array = tf_weights[name] + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if "kernel" in name or "proj" in name: + array = np.transpose(array) + if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1: + # Here we will split the TF weights + assert len(pointer) == array.shape[0] + for i, p_i in enumerate(pointer): + arr_i = array[i, ...] + try: + assert p_i.shape == arr_i.shape + except AssertionError as e: + e.args += (p_i.shape, arr_i.shape) + raise + logger.info(f"Initialize PyTorch weight {name} for layer {i}") + p_i.data = torch.from_numpy(arr_i) + else: + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + tf_weights.pop(name, None) + tf_weights.pop(name + "/Adam", None) + tf_weights.pop(name + "/Adam_1", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model + + def convert_transfo_xl_checkpoint_to_pytorch( tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file ): diff --git a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py index b0abc30cd758..00d6fecc21b0 100644 --- a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py @@ -15,16 +15,100 @@ """Convert ELECTRA checkpoint.""" import argparse +import os import torch -from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra +from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + for name, array in zip(names, arrays): + original_name: str = name + + try: + if isinstance(model, ElectraForMaskedLM): + name = name.replace("electra/embeddings/", "generator/embeddings/") + + if discriminator_or_generator == "generator": + name = name.replace("electra/", "discriminator/") + name = name.replace("generator/", "electra/") + + name = name.replace("dense_1", "dense_prediction") + name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias") + + name = name.split("/") + # print(original_name, name) + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["global_step", "temperature"] for n in name): + logger.info(f"Skipping {original_name}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name.endswith("_embeddings"): + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + print(f"Initialize PyTorch weight {name}", original_name) + pointer.data = torch.from_numpy(array) + except AttributeError as e: + print(f"Skipping {original_name}", name, e) + continue + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): # Initialise PyTorch model config = ElectraConfig.from_json_file(config_file) diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py index 4eab188f2ab7..25f7483732da 100755 --- a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py @@ -15,16 +15,109 @@ """Convert Funnel checkpoint.""" import argparse +import os import torch -from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel +from transformers import FunnelBaseModel, FunnelConfig, FunnelModel +from transformers.models.funnel.modeling_funnel import FunnelPositionwiseFFN, FunnelRelMultiheadAttention from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_funnel(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + _layer_map = { + "k": "k_head", + "q": "q_head", + "v": "v_head", + "o": "post_proj", + "layer_1": "linear_1", + "layer_2": "linear_2", + "rel_attn": "attention", + "ff": "ffn", + "kernel": "weight", + "gamma": "weight", + "beta": "bias", + "lookup_table": "weight", + "word_embedding": "word_embeddings", + "input": "embeddings", + } + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + if name[0] == "generator": + continue + pointer = model + skipped = False + for m_name in name[1:]: + if not isinstance(pointer, FunnelPositionwiseFFN) and re.fullmatch(r"layer_\d+", m_name): + layer_index = int(re.search(r"layer_(\d+)", m_name).groups()[0]) + if layer_index < config.num_hidden_layers: + block_idx = 0 + while layer_index >= config.block_sizes[block_idx]: + layer_index -= config.block_sizes[block_idx] + block_idx += 1 + pointer = pointer.blocks[block_idx][layer_index] + else: + layer_index -= config.num_hidden_layers + pointer = pointer.layers[layer_index] + elif m_name == "r" and isinstance(pointer, FunnelRelMultiheadAttention): + pointer = pointer.r_kernel + break + elif m_name in _layer_map: + pointer = getattr(pointer, _layer_map[m_name]) + else: + try: + pointer = getattr(pointer, m_name) + except AttributeError: + print(f"Skipping {'/'.join(name)}", array.shape) + skipped = True + break + if not skipped: + if len(pointer.shape) != len(array.shape): + array = array.reshape(pointer.shape) + if m_name == "kernel": + array = np.transpose(array) + pointer.data = torch.from_numpy(array) + + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model): # Initialise PyTorch model config = FunnelConfig.from_json_file(config_file) diff --git a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py index 33f9dabed07f..8fba497c49a8 100755 --- a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,72 @@ """Convert OpenAI GPT checkpoint.""" import argparse +import os import torch -from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2 +from transformers import GPT2Config, GPT2Model from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py index 3db22857293c..6c52a515b6c4 100644 --- a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py +++ b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py @@ -16,12 +16,97 @@ import argparse import json +import os -from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo +import torch +import torch.nn as nn + +from transformers import GPTNeoConfig, GPTNeoForCausalLM from transformers.utils import logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt_neo_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + if "global_step" not in name and "adam" not in name: + array = tf.train.load_variable(tf_path, name) + array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy() + name = name.replace("attn/q", "attn/attention/q_proj/w") + name = name.replace("attn/k", "attn/attention/k_proj/w") + name = name.replace("attn/v", "attn/attention/v_proj/w") + name = name.replace("attn/o", "attn/attention/out_proj/w") + name = name.replace("norm_1", "ln_1") + name = name.replace("norm_2", "ln_2") + name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b") + name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w") + name = name.replace("conv1d_main/c_fc/bias", "c_fc/b") + name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w") + name = name.replace("conv1d_main/c_proj/bias", "c_proj/b") + + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name[5:] # skip "gpt2/" + name = name.split("/") + pointer = model.transformer + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: + array = array.transpose() + + if name == ["wte"]: + # if vocab is padded, then trim off the padding embeddings + array = array[: config.vocab_size] + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}") + + print(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + + # init the final linear layer using word embeddings + embs = model.transformer.wte.weight + lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False) + lin.weight = embs + model.set_output_embeddings(lin) + return model def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): diff --git a/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py b/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py index 182d66b9af28..a1bb2efee2e1 100644 --- a/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py +++ b/src/transformers/models/imagegpt/convert_imagegpt_original_tf2_to_pytorch.py @@ -15,14 +15,124 @@ """Convert OpenAI Image GPT checkpoints.""" import argparse +import os import torch -from transformers import ImageGPTConfig, ImageGPTForCausalLM, load_tf_weights_in_imagegpt +from transformers import ImageGPTConfig, ImageGPTForCausalLM from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path): + """ + Load tf checkpoints in a pytorch model + """ + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(imagegpt_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ) or name[-1] in ["_step"]: + logger.info("Skipping {}".format("/".join(name))) + continue + + pointer = model + if name[-1] not in ["wtet"]: + pointer = getattr(pointer, "transformer") + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]: + pointer = getattr(pointer, "c_attn") + pointer = getattr(pointer, "weight") + elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + elif scope_names[0] == "wtet": + pointer = getattr(pointer, "lm_head") + pointer = getattr(pointer, "weight") + elif scope_names[0] == "sos": + pointer = getattr(pointer, "wte") + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + if len(name) > 1 and name[1] == "attn" or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte": + pass # array is used to initialize only part of the pointer so sizes won't match + else: + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + + logger.info(f"Initialize PyTorch weight {name}") + + if name[-1] == "q_proj": + pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T + elif name[-1] == "k_proj": + pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy( + array.reshape(config.n_embd, config.n_embd) + ).T + elif name[-1] == "v_proj": + pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T + elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj": + pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)) + elif name[-1] == "wtet": + pointer.data = torch.from_numpy(array) + elif name[-1] == "wte": + pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array) + elif name[-1] == "sos": + pointer.data[-1] = torch.from_numpy(array) + else: + pointer.data = torch.from_numpy(array) + + return model def convert_imagegpt_checkpoint_to_pytorch(imagegpt_checkpoint_path, model_size, pytorch_dump_folder_path): diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py index 1dd77bc36f80..bf93a1cad190 100755 --- a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py @@ -15,16 +15,97 @@ """Convert LXMERT checkpoint.""" import argparse +import os import torch -from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert +from transformers import LxmertConfig, LxmertForPreTraining from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): # Initialise PyTorch model config = LxmertConfig.from_json_file(config_file) diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py index 022a9d036cdb..53288953d81e 100644 --- a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py @@ -13,14 +13,94 @@ # limitations under the License. import argparse +import os import torch -from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert +from transformers import MobileBertConfig, MobileBertForPreTraining from transformers.utils import logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.replace("ffn_layer", "ffn") + name = name.replace("FakeLayerNorm", "LayerNorm") + name = name.replace("extra_output_weights", "dense/kernel") + name = name.replace("bert", "mobilebert") + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): diff --git a/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py index 1b53bbeab475..b0eeb2874aa6 100644 --- a/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/mobilenet_v1/convert_original_tf_checkpoint_to_pytorch.py @@ -28,7 +28,6 @@ MobileNetV1Config, MobileNetV1ForImageClassification, MobileNetV1ImageProcessor, - load_tf_weights_in_mobilenet_v1, ) from transformers.utils import logging @@ -37,6 +36,109 @@ logger = logging.get_logger(__name__) +def _build_tf_to_pytorch_map(model, config, tf_weights=None): + """ + A map of modules from TF to PyTorch. + """ + + tf_to_pt_map = {} + + if isinstance(model, MobileNetV1ForImageClassification): + backbone = model.mobilenet_v1 + else: + backbone = model + + prefix = "MobilenetV1/Conv2d_0/" + tf_to_pt_map[prefix + "weights"] = backbone.conv_stem.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = backbone.conv_stem.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = backbone.conv_stem.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.normalization.running_var + + for i in range(13): + tf_index = i + 1 + pt_index = i * 2 + + pointer = backbone.layer[pt_index] + prefix = f"MobilenetV1/Conv2d_{tf_index}_depthwise/" + tf_to_pt_map[prefix + "depthwise_weights"] = pointer.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var + + pointer = backbone.layer[pt_index + 1] + prefix = f"MobilenetV1/Conv2d_{tf_index}_pointwise/" + tf_to_pt_map[prefix + "weights"] = pointer.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = pointer.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = pointer.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.normalization.running_var + + if isinstance(model, MobileNetV1ForImageClassification): + prefix = "MobilenetV1/Logits/Conv2d_1c_1x1/" + tf_to_pt_map[prefix + "weights"] = model.classifier.weight + tf_to_pt_map[prefix + "biases"] = model.classifier.bias + + return tf_to_pt_map + + +def load_tf_weights_in_mobilenet_v1(model, config, tf_checkpoint_path): + """Load TensorFlow checkpoints in a PyTorch model.""" + try: + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_checkpoint_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_checkpoint_path, name) + tf_weights[name] = array + + # Build TF to PyTorch weights loading map + tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights) + + for name, pointer in tf_to_pt_map.items(): + logger.info(f"Importing {name}") + if name not in tf_weights: + logger.info(f"{name} not in tf pre-trained weights, skipping") + continue + + array = tf_weights[name] + + if "depthwise_weights" in name: + logger.info("Transposing depthwise") + array = np.transpose(array, (2, 3, 0, 1)) + elif "weights" in name: + logger.info("Transposing") + if len(pointer.shape) == 2: # copying into linear layer + array = array.squeeze().transpose() + else: + array = np.transpose(array, (3, 2, 0, 1)) + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name} {array.shape}") + pointer.data = torch.from_numpy(array) + + tf_weights.pop(name, None) + tf_weights.pop(name + "/RMSProp", None) + tf_weights.pop(name + "/RMSProp_1", None) + tf_weights.pop(name + "/ExponentialMovingAverage", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model + + def get_mobilenet_v1_config(model_name): config = MobileNetV1Config(layer_norm_eps=0.001) diff --git a/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py index 1fdb9783ccf0..6f94b074b440 100644 --- a/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/mobilenet_v2/convert_original_tf_checkpoint_to_pytorch.py @@ -29,7 +29,6 @@ MobileNetV2ForImageClassification, MobileNetV2ForSemanticSegmentation, MobileNetV2ImageProcessor, - load_tf_weights_in_mobilenet_v2, ) from transformers.utils import logging @@ -38,6 +37,175 @@ logger = logging.get_logger(__name__) +def _build_tf_to_pytorch_map(model, config, tf_weights=None): + """ + A map of modules from TF to PyTorch. + """ + + tf_to_pt_map = {} + + if isinstance(model, (MobileNetV2ForImageClassification, MobileNetV2ForSemanticSegmentation)): + backbone = model.mobilenet_v2 + else: + backbone = model + + # Use the EMA weights if available + def ema(x): + return x + "/ExponentialMovingAverage" if x + "/ExponentialMovingAverage" in tf_weights else x + + prefix = "MobilenetV2/Conv/" + tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_stem.first_conv.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.first_conv.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.first_conv.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.first_conv.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.first_conv.normalization.running_var + + prefix = "MobilenetV2/expanded_conv/depthwise/" + tf_to_pt_map[ema(prefix + "depthwise_weights")] = backbone.conv_stem.conv_3x3.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.conv_3x3.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.conv_3x3.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.conv_3x3.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.conv_3x3.normalization.running_var + + prefix = "MobilenetV2/expanded_conv/project/" + tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_stem.reduce_1x1.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_stem.reduce_1x1.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_stem.reduce_1x1.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_stem.reduce_1x1.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_stem.reduce_1x1.normalization.running_var + + for i in range(16): + tf_index = i + 1 + pt_index = i + pointer = backbone.layer[pt_index] + + prefix = f"MobilenetV2/expanded_conv_{tf_index}/expand/" + tf_to_pt_map[ema(prefix + "weights")] = pointer.expand_1x1.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.expand_1x1.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.expand_1x1.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.expand_1x1.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.expand_1x1.normalization.running_var + + prefix = f"MobilenetV2/expanded_conv_{tf_index}/depthwise/" + tf_to_pt_map[ema(prefix + "depthwise_weights")] = pointer.conv_3x3.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.conv_3x3.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.conv_3x3.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.conv_3x3.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.conv_3x3.normalization.running_var + + prefix = f"MobilenetV2/expanded_conv_{tf_index}/project/" + tf_to_pt_map[ema(prefix + "weights")] = pointer.reduce_1x1.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = pointer.reduce_1x1.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = pointer.reduce_1x1.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = pointer.reduce_1x1.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = pointer.reduce_1x1.normalization.running_var + + prefix = "MobilenetV2/Conv_1/" + tf_to_pt_map[ema(prefix + "weights")] = backbone.conv_1x1.convolution.weight + tf_to_pt_map[ema(prefix + "BatchNorm/beta")] = backbone.conv_1x1.normalization.bias + tf_to_pt_map[ema(prefix + "BatchNorm/gamma")] = backbone.conv_1x1.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = backbone.conv_1x1.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = backbone.conv_1x1.normalization.running_var + + if isinstance(model, MobileNetV2ForImageClassification): + prefix = "MobilenetV2/Logits/Conv2d_1c_1x1/" + tf_to_pt_map[ema(prefix + "weights")] = model.classifier.weight + tf_to_pt_map[ema(prefix + "biases")] = model.classifier.bias + + if isinstance(model, MobileNetV2ForSemanticSegmentation): + prefix = "image_pooling/" + tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_pool.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_pool.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_pool.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = model.segmentation_head.conv_pool.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( + model.segmentation_head.conv_pool.normalization.running_var + ) + + prefix = "aspp0/" + tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_aspp.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_aspp.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_aspp.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = model.segmentation_head.conv_aspp.normalization.running_mean + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( + model.segmentation_head.conv_aspp.normalization.running_var + ) + + prefix = "concat_projection/" + tf_to_pt_map[prefix + "weights"] = model.segmentation_head.conv_projection.convolution.weight + tf_to_pt_map[prefix + "BatchNorm/beta"] = model.segmentation_head.conv_projection.normalization.bias + tf_to_pt_map[prefix + "BatchNorm/gamma"] = model.segmentation_head.conv_projection.normalization.weight + tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = ( + model.segmentation_head.conv_projection.normalization.running_mean + ) + tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = ( + model.segmentation_head.conv_projection.normalization.running_var + ) + + prefix = "logits/semantic/" + tf_to_pt_map[ema(prefix + "weights")] = model.segmentation_head.classifier.convolution.weight + tf_to_pt_map[ema(prefix + "biases")] = model.segmentation_head.classifier.convolution.bias + + return tf_to_pt_map + + +def load_tf_weights_in_mobilenet_v2(model, config, tf_checkpoint_path): + """Load TensorFlow checkpoints in a PyTorch model.""" + try: + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_checkpoint_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_checkpoint_path, name) + tf_weights[name] = array + + # Build TF to PyTorch weights loading map + tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights) + + for name, pointer in tf_to_pt_map.items(): + logger.info(f"Importing {name}") + if name not in tf_weights: + logger.info(f"{name} not in tf pre-trained weights, skipping") + continue + + array = tf_weights[name] + + if "depthwise_weights" in name: + logger.info("Transposing depthwise") + array = np.transpose(array, (2, 3, 0, 1)) + elif "weights" in name: + logger.info("Transposing") + if len(pointer.shape) == 2: # copying into linear layer + array = array.squeeze().transpose() + else: + array = np.transpose(array, (3, 2, 0, 1)) + + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name} {array.shape}") + pointer.data = torch.from_numpy(array) + + tf_weights.pop(name, None) + tf_weights.pop(name + "/RMSProp", None) + tf_weights.pop(name + "/RMSProp_1", None) + tf_weights.pop(name + "/ExponentialMovingAverage", None) + tf_weights.pop(name + "/Momentum", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model + + def get_mobilenet_v2_config(model_name): config = MobileNetV2Config(layer_norm_eps=0.001) diff --git a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py index 39653e4b1c77..0f8a17d1adce 100644 --- a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,123 @@ """Convert MyT5 checkpoint.""" import argparse +import os -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +import torch + +from transformers import T5Config, T5ForConditionalGeneration from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + # Copied from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): # Initialise PyTorch model diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py index 3d5218c20426..df8dddce9828 100755 --- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,95 @@ """Convert OpenAI GPT checkpoint.""" import argparse +import json +import os import torch -from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt +from transformers import OpenAIGPTConfig, OpenAIGPTModel from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): + """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" + import re + + import numpy as np + + if ".ckpt" in openai_checkpoint_folder_path: + openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) + + logger.info(f"Loading weights from {openai_checkpoint_folder_path}") + + with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: + names = json.load(names_handle) + with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: + shapes = json.load(shapes_handle) + offsets = np.cumsum([np.prod(shape) for shape in shapes]) + init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] + init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] + init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] + + # This was used when we had a single embedding matrix for positions and tokens + # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) + # del init_params[1] + init_params = [arr.squeeze() for arr in init_params] + + # Check that the token and position embeddings weight dimensions map those of the init parameters. + if model.tokens_embed.weight.shape != init_params[1].shape: + raise ValueError( + f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:" + f" {init_params[1].shape}" + ) + + if model.positions_embed.weight.shape != init_params[0].shape: + raise ValueError( + f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:" + f" {init_params[0].shape}" + ) + + model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) + model.positions_embed.weight.data = torch.from_numpy(init_params[0]) + names.pop(0) + # Pop position and token embedding arrays + init_params.pop(0) + init_params.pop(0) + + for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): + name = name[6:] # skip "model/" + if name[-2:] != ":0": + raise ValueError(f"Layer {name} does not end with :0") + name = name[:-2] + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "w": + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + # Ensure that the pointer and array have compatible shapes. + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): diff --git a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py index 369388c540f9..7964ba9fb275 100755 --- a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py @@ -15,14 +15,99 @@ """Convert RemBERT checkpoint.""" import argparse +import os import torch -from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert +from transformers import RemBertConfig, RemBertModel from transformers.utils import logging logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_rembert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + # Checkpoint is 12Gb, save memory by not loading useless variables + # Output embedding and cls are reset at classification time + if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")): + # logger.info("Skipping loading of %s", name) + continue + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + # Replace prefix with right one + name = name.replace("bert/", "rembert/") + # The pooler is a linear layer + # name = name.replace("pooler/dense", "pooler") + + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): diff --git a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py index d227948e0ee3..f68152643da8 100755 --- a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py @@ -15,16 +15,91 @@ """Convert RoFormer checkpoint.""" import argparse +import os import torch -from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer +from transformers import RoFormerConfig, RoFormerForMaskedLM from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_roformer(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name.replace("bert", "roformer")) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if not pointer.shape == array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): # Initialise PyTorch model config = RoFormerConfig.from_json_file(bert_config_file) diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py index 9b1b15857cea..a53efce63544 100755 --- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py @@ -15,14 +15,123 @@ """Convert T5 checkpoint.""" import argparse +import os -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +import torch + +from transformers import T5Config, T5ForConditionalGeneration from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): # Initialise PyTorch model config = T5Config.from_json_file(config_file) diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py index 34bf77cccd6b..d9400b366e5f 100644 --- a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py @@ -15,6 +15,9 @@ """Convert TAPAS checkpoint.""" import argparse +import os + +import torch from transformers import ( TapasConfig, @@ -23,14 +26,148 @@ TapasForSequenceClassification, TapasModel, TapasTokenizer, - load_tf_weights_in_tapas, ) from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_tf_weights_in_tapas(model, config, tf_checkpoint_path): + """ + Load tf checkpoints in a PyTorch model. This is an adaptation from load_tf_weights_in_bert + + - add cell selection and aggregation heads + - take into account additional token type embedding layers + """ + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + "seq_relationship", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasForSequenceClassification, we skip output_bias and output_weights + # since these are not used for classification + if isinstance(model, TapasForSequenceClassification): + if any(n in ["output_bias", "output_weights"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasModel, we skip output_bias, output_weights, output_bias_cls and output_weights_cls + # since this model does not have MLM and NSP heads + if isinstance(model, TapasModel): + if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasForMaskedLM, we skip the pooler + if isinstance(model, TapasForMaskedLM): + if any(n in ["pooler"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # if first scope name starts with "bert", change it to "tapas" + if name[0] == "bert": + name[0] = "tapas" + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + # cell selection heads + elif scope_names[0] == "output_bias": + if not isinstance(model, TapasForMaskedLM): + pointer = getattr(pointer, "output_bias") + else: + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "output_weights") + elif scope_names[0] == "column_output_bias": + pointer = getattr(pointer, "column_output_bias") + elif scope_names[0] == "column_output_weights": + pointer = getattr(pointer, "column_output_weights") + # aggregation head + elif scope_names[0] == "output_bias_agg": + pointer = getattr(pointer, "aggregation_classifier") + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights_agg": + pointer = getattr(pointer, "aggregation_classifier") + pointer = getattr(pointer, "weight") + # classification head + elif scope_names[0] == "output_bias_cls": + pointer = getattr(pointer, "classifier") + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights_cls": + pointer = getattr(pointer, "classifier") + pointer = getattr(pointer, "weight") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name[-13:] in [f"_embeddings_{i}" for i in range(7)]: + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + # Added a check to see whether the array is a scalar (because bias terms in Tapas checkpoints can be + # scalar => should first be converted to numpy arrays) + if np.isscalar(array): + array = np.array(array) + pointer.data = torch.from_numpy(array) + return model + + def convert_tf_checkpoint_to_pytorch( task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path ): diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py index a15c5f22ad68..81aef230ac43 100755 --- a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py @@ -24,7 +24,6 @@ XLNetForQuestionAnswering, XLNetForSequenceClassification, XLNetLMHeadModel, - load_tf_weights_in_xlnet, ) from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging @@ -43,6 +42,157 @@ logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None): + """ + A map of modules from TF to PyTorch. I use a map to keep the PyTorch model as identical to the original PyTorch + model as possible. + """ + + tf_to_pt_map = {} + + if hasattr(model, "transformer"): + if hasattr(model, "lm_loss"): + # We will load also the output bias + tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias + if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights: + # We will load also the sequence summary + tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight + tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias + if ( + hasattr(model, "logits_proj") + and config.finetuning_task is not None + and f"model/regression_{config.finetuning_task}/logit/kernel" in tf_weights + ): + tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/kernel"] = model.logits_proj.weight + tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/bias"] = model.logits_proj.bias + + # Now load the rest of the transformer + model = model.transformer + + # Embeddings and output + tf_to_pt_map.update( + { + "model/transformer/word_embedding/lookup_table": model.word_embedding.weight, + "model/transformer/mask_emb/mask_emb": model.mask_emb, + } + ) + + # Transformer blocks + for i, b in enumerate(model.layer): + layer_str = f"model/transformer/layer_{i}/" + tf_to_pt_map.update( + { + layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight, + layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias, + layer_str + "rel_attn/o/kernel": b.rel_attn.o, + layer_str + "rel_attn/q/kernel": b.rel_attn.q, + layer_str + "rel_attn/k/kernel": b.rel_attn.k, + layer_str + "rel_attn/r/kernel": b.rel_attn.r, + layer_str + "rel_attn/v/kernel": b.rel_attn.v, + layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight, + layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias, + layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight, + layer_str + "ff/layer_1/bias": b.ff.layer_1.bias, + layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight, + layer_str + "ff/layer_2/bias": b.ff.layer_2.bias, + } + ) + + # Relative positioning biases + if config.untie_r: + r_r_list = [] + r_w_list = [] + r_s_list = [] + seg_embed_list = [] + for b in model.layer: + r_r_list.append(b.rel_attn.r_r_bias) + r_w_list.append(b.rel_attn.r_w_bias) + r_s_list.append(b.rel_attn.r_s_bias) + seg_embed_list.append(b.rel_attn.seg_embed) + else: + r_r_list = [model.r_r_bias] + r_w_list = [model.r_w_bias] + r_s_list = [model.r_s_bias] + seg_embed_list = [model.seg_embed] + tf_to_pt_map.update( + { + "model/transformer/r_r_bias": r_r_list, + "model/transformer/r_w_bias": r_w_list, + "model/transformer/r_s_bias": r_s_list, + "model/transformer/seg_embed": seg_embed_list, + } + ) + return tf_to_pt_map + + +def load_tf_weights_in_xlnet(model, config, tf_path): + """Load tf checkpoints in a pytorch model""" + try: + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + tf_weights[name] = array + + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights) + + for name, pointer in tf_to_pt_map.items(): + logger.info(f"Importing {name}") + if name not in tf_weights: + logger.info(f"{name} not in tf pre-trained weights, skipping") + continue + array = tf_weights[name] + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name): + logger.info("Transposing") + array = np.transpose(array) + if isinstance(pointer, list): + # Here we will split the TF weights + assert len(pointer) == array.shape[0], ( + f"Pointer length {len(pointer)} and array length {array.shape[0]} mismatched" + ) + for i, p_i in enumerate(pointer): + arr_i = array[i, ...] + try: + assert p_i.shape == arr_i.shape, ( + f"Pointer shape {p_i.shape} and array shape {arr_i.shape} mismatched" + ) + except AssertionError as e: + e.args += (p_i.shape, arr_i.shape) + raise + logger.info(f"Initialize PyTorch weight {name} for layer {i}") + p_i.data = torch.from_numpy(arr_i) + else: + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + tf_weights.pop(name, None) + tf_weights.pop(name + "/Adam", None) + tf_weights.pop(name + "/Adam_1", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model def convert_xlnet_checkpoint_to_pytorch( From df215f9979fb808a8bdc723092cba3e79144da4b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 18 Sep 2025 17:39:35 +0200 Subject: [PATCH 35/35] re-add --- ...bert_original_tf1_checkpoint_to_pytorch.py | 4 + ...ers_original_flax_checkpoint_to_pytorch.py | 145 +++++++++++++++++- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py index 26e0328ac06f..350e5a7f3f90 100644 --- a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py +++ b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py @@ -16,14 +16,18 @@ import argparse import os +from operator import attrgetter + import torch from transformers import ConvBertConfig, ConvBertModel from transformers.utils import logging + logger = logging.get_logger(__name__) logging.set_verbosity_info() + def load_tf_weights_in_convbert(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" try: diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 71d304ea96c6..e73a1f7181ba 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -18,17 +18,160 @@ import argparse import re +import jax +import jax.numpy as jnp +import numpy as np from flax.traverse_util import flatten_dict, unflatten_dict from t5x import checkpoints from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration -from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.utils import logging +logger = logging.get_logger(__name__) logging.set_verbosity_info() +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load flax checkpoints in a PyTorch model""" + + try: + import torch # noqa: F401 + except (ImportError, ModuleNotFoundError): + logger.error( + "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation" + " instructions." + ) + raise + + # check if we have bf16 weights + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + if any(is_type_bf16): + # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 + # and bf16 is not fully supported in PT yet. + logger.warning( + "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " + "before loading those in PyTorch model." + ) + flax_state = jax.tree_util.tree_map( + lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state + ) + + flax_state_dict = flatten_dict(flax_state) + pt_model_dict = pt_model.state_dict() + + load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( + pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict} + ) + load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( + pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict} + ) + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix + require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict + + # adapt flax_key to prepare for loading from/to base model only + if load_model_with_head_into_base_model and has_base_model_prefix: + flax_key_tuple = flax_key_tuple[1:] + elif load_base_model_into_model_with_head and require_base_model_prefix: + flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple + + # rename flax weights to PyTorch format + if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict: + # conv layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) + elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict: + # linear layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = flax_tensor.T + elif flax_key_tuple[-1] in ["scale", "embedding"]: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + + # adding batch stats from flax batch norm to pt + elif "mean" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) + elif "var" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) + + if "batch_stats" in flax_state: + flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header + else: + flax_key = ".".join(flax_key_tuple) + + # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation. + special_pt_names = {} + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + for key in pt_model_dict: + key_components = key.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + key_to_check = ".".join(key_components) + special_pt_names[key_to_check] = key + + if flax_key in special_pt_names: + flax_key = special_pt_names[flax_key] + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " + f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + else: + # add weight to pytorch dict + flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor + pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) + # remove from missing keys + missing_keys.remove(flax_key) + else: + # weight is not expected by PyTorch model + unexpected_keys.append(flax_key) + + pt_model.load_state_dict(pt_model_dict) + + # re-transform missing_keys to list + missing_keys = list(missing_keys) + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the Flax model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" + f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " FlaxBertForSequenceClassification model)." + ) + else: + logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." + ) + + return pt_model + + # should not include what is already done by the `from_pt` argument MOE_LAYER_NAME_MAPPING = { "/attention/": "/0/SelfAttention/",